diff --git a/LICENSE b/LICENSE index 2071b23..261eeb9 100644 --- a/LICENSE +++ b/LICENSE @@ -1,9 +1,201 @@ -MIT License + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -Copyright (c) + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + 1. Definitions. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 61da0df..b707509 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,803 @@ -# jstarcraft-rns-master +**A search and recommendation engine built on JStarCraft** + +## 1. Project Introduction + +**JStarCraft RNS is a lightweight engine for the information retrieval domain, released under the Apache 2.0 license.** + +It focuses on the two fundamental problems of information retrieval: recommendation and search. + +It provides a recommendation engine designed and implemented for industrial-grade scenarios. + +It provides a search engine designed and implemented for industrial-grade scenarios.
+ +**** + +## 2. Features + +* 1. Cross-platform +* [2. Serial and parallel computation](https://github.com/HongZhaoHua/jstarcraft-ai) +* [3. CPU and GPU hardware acceleration](https://github.com/HongZhaoHua/jstarcraft-ai) +* [4. Model saving and loading](https://github.com/HongZhaoHua/jstarcraft-ai) +* 5. A rich set of recommendation and search algorithms +* 6. Rich scripting support + * Groovy + * JS + * Lua + * MVEL + * Python + * Ruby +* [7. A rich set of evaluation metrics](#评估指标) + * [Ranking metrics](#排序指标) + * [Rating metrics](#评分指标) + +**** + +## 3. Installation + +JStarCraft RNS requires the following environment: +* JDK 8 or above +* Maven 3 + +#### 3.1 Install the JStarCraft-Core framework + +```shell +git clone https://github.com/HongZhaoHua/jstarcraft-core.git + +mvn install -Dmaven.test.skip=true +``` + +#### 3.2 Install the JStarCraft-AI framework + +```shell +git clone https://github.com/HongZhaoHua/jstarcraft-ai.git + +mvn install -Dmaven.test.skip=true +``` + +#### 3.3 Install the JStarCraft-RNS engine + +```shell +git clone https://github.com/HongZhaoHua/jstarcraft-rns.git + +mvn install -Dmaven.test.skip=true +``` + +**** + +## 4. Usage + +#### 4.1 Set up the dependency + +* Maven dependency + +```xml +<dependency> + <groupId>com.jstarcraft</groupId> + <artifactId>rns</artifactId> + <version>1.0</version> +</dependency> +``` + +* Gradle dependency + +```gradle +compile group: 'com.jstarcraft', name: 'rns', version: '1.0' +``` + +#### 4.2 Build a configurator + +```java +Properties keyValues = new Properties(); +keyValues.load(this.getClass().getResourceAsStream("/data.properties")); +keyValues.load(this.getClass().getResourceAsStream("/recommend/benchmark/randomguess-test.properties")); +Configurator configurator = new Configurator(keyValues); +``` + +#### 4.3 Train and evaluate a model + +* Build a ranking task + +```java +RankingTask task = new RankingTask(RandomGuessModel.class, configurator); +// Train and evaluate the ranking model +task.execute(); +``` + +* Build a rating task + +```java +RatingTask task = new RatingTask(RandomGuessModel.class, configurator); +// Train and evaluate the rating model +task.execute(); +``` + +#### 4.4 Get the model + +```java +// Get the model +Model model = task.getModel(); +```
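Putting sections 4.1 through 4.4 together, the following minimal sketch runs both task types end to end. It only recombines the calls shown above (`Configurator`, `RankingTask`, `RatingTask`, `execute()`, `getModel()`) against the same two property files; treat it as a sketch rather than a canonical program.

```java
// Build one configuration shared by both tasks (section 4.2)
Properties keyValues = new Properties();
keyValues.load(this.getClass().getResourceAsStream("/data.properties"));
keyValues.load(this.getClass().getResourceAsStream("/recommend/benchmark/randomguess-test.properties"));
Configurator configurator = new Configurator(keyValues);

// Ranking: train and evaluate, then keep the trained model (sections 4.3 and 4.4)
RankingTask rankingTask = new RankingTask(RandomGuessModel.class, configurator);
rankingTask.execute();
Model rankingModel = rankingTask.getModel();

// Rating: the same lifecycle with the same configuration
RatingTask ratingTask = new RatingTask(RandomGuessModel.class, configurator);
ratingTask.execute();
Model ratingModel = ratingTask.getModel();
```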
+ +**** + +## 5. Concepts + +#### 5.1 Why information retrieval is needed + +``` +With the development of information technology and the Internet, people have gradually moved from an era of information underload into an era of information overload. + +Both information consumers and information producers face challenges: +* for information consumers, finding the right information within a massive volume of information is very difficult; +* for information producers, making information stand out from a massive volume of information is also very difficult; + +The task of information retrieval is to connect users and information: on the one hand it helps users find the information valuable to them, and on the other hand it helps information reach the users interested in it, a win-win for information consumers and information producers. +``` + +#### 5.2 How search and recommendation differ + +``` +From the information retrieval perspective: +* search and recommendation are the two main means of obtaining information; +* search and recommendation are two different ways of obtaining information; + * search (Search) is active and explicit; + * recommendation (Recommend) is passive and fuzzy; + +Search and recommendation are two complementary tools. +``` + +#### 5.3 What problems the JStarCraft-RNS engine solves + +``` +The JStarCraft-RNS engine addresses the two core tasks of the recommendation and search domains: ranking prediction (Ranking) and rating prediction (Rating). +``` + +#### 5.4 The difference between Ranking tasks and Rating tasks + +``` +Algorithms and evaluation metrics are divided into ranking (Ranking) and rating (Rating) according to the basic problem they solve. + +The root difference between the two lies in their objective functions. +In plain terms: +a Ranking algorithm is based on implicit feedback data and tends to fit the user's ordering (attention); +a Rating algorithm is based on explicit feedback data and tends to fit the user's scores (satisfaction). +``` + +#### 5.5 Can a Rating algorithm be used for a Ranking problem + +``` +The key is whether attention and satisfaction stay consistent in the specific scenario. +In plain terms: +what people pay attention to is not necessarily what they are satisfied with (for example: personal income tax). +```
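To make the objective-function difference in 5.4 concrete, here is a small self-contained sketch (illustrative only; these names are hypothetical and not part of the JStarCraft API). A Rating-style model minimizes a pointwise squared error against explicit scores, while a Ranking-style model such as BPR minimizes a pairwise loss that rewards scoring an observed item above an unobserved one:

```java
// Illustrative contrast of the two objective families from section 5.4.
public final class ObjectiveSketch {

    // Rating (pointwise): squared error between a predicted and an explicit score.
    static float ratingLoss(float predictedScore, float explicitScore) {
        float error = predictedScore - explicitScore;
        return error * error;
    }

    // Ranking (pairwise, BPR-style): negative log-probability that a positive
    // (observed) item is ranked above a negative (unobserved) one.
    static float rankingLoss(float positiveScore, float negativeScore) {
        double margin = positiveScore - negativeScore;
        return (float) -Math.log(1.0 / (1.0 + Math.exp(-margin)));
    }

    public static void main(String[] arguments) {
        System.out.println(ratingLoss(3.5F, 4.0F));  // small error, small loss
        System.out.println(rankingLoss(2.0F, 0.5F)); // good ordering, small loss
    }
}
```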
+context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// 设置Groovy脚本使用到的Java变量 +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// 执行Groovy脚本 +ScriptExpression expression = new GroovyExpression(context, scope, script); +Map data = expression.doWith(Map.class); +``` + +#### 6.3JStarCraft-RNS引擎与JS脚本交互 + +* [完整示例](https://github.com/HongZhaoHua/jstarcraft-rns/tree/master/src/test/java/com/jstarcraft/rns/script) + +* 编写JS脚本训练与评估模型并保存到Model.js文件 + +```js +// 构建配置 +var keyValues = new Properties(); +keyValues.load(loader.getResourceAsStream("data.properties")); +keyValues.load(loader.getResourceAsStream("recommend/benchmark/randomguess-test.properties")); +var configurator = new Configurator([keyValues]); + +// 此对象会返回给Java程序 +var _data = {}; + +// 构建排序任务 +task = new RankingTask(RandomGuessModel.class, configurator); +// 训练与评估模型并获取排序指标 +measures = task.execute(); +_data['precision'] = measures.get(PrecisionEvaluator.class); +_data['recall'] = measures.get(RecallEvaluator.class); + +// 构建评分任务 +task = new RatingTask(RandomGuessModel.class, configurator); +// 训练与评估模型并获取评分指标 +measures = task.execute(); +_data['mae'] = measures.get(MAEEvaluator.class); +_data['mse'] = measures.get(MSEEvaluator.class); + +_data; +``` + +* 使用JStarCraft框架从Model.js文件加载并执行JS脚本 + +```java +// 获取JS脚本 +File file = new File(ScriptTestCase.class.getResource("Model.js").toURI()); +String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + +// 设置JS脚本使用到的Java类 +ScriptContext context = new ScriptContext(); +context.useClasses(Properties.class, Assert.class); +context.useClass("Configurator", MapConfigurator.class); +context.useClasses("com.jstarcraft.ai.evaluate"); +context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// 设置JS脚本使用到的Java变量 +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// 执行JS脚本 +ScriptExpression expression = new JsExpression(context, scope, script); +Map data = expression.doWith(Map.class); +``` + +#### 6.4JStarCraft-RNS引擎与Kotlin脚本交互 + +* [完整示例](https://github.com/HongZhaoHua/jstarcraft-rns/tree/master/src/test/java/com/jstarcraft/rns/script) + +* 编写Kotlin脚本训练与评估模型并保存到Model.kt文件 + +```js +// 构建配置 +var keyValues = Properties(); +var loader = bindings["loader"] as ClassLoader; +keyValues.load(loader.getResourceAsStream("data.properties")); +keyValues.load(loader.getResourceAsStream("model/benchmark/randomguess-test.properties")); +var option = Option(keyValues); + +// 此对象会返回给Java程序 +var _data = mutableMapOf(); + +// 构建排序任务 +var rankingTask = RankingTask(RandomGuessModel::class.java, option); +// 训练与评估模型并获取排序指标 +val rankingMeasures = rankingTask.execute(); +_data["precision"] = rankingMeasures.getFloat(PrecisionEvaluator::class.java); +_data["recall"] = rankingMeasures.getFloat(RecallEvaluator::class.java); + +// 构建评分任务 +var ratingTask = RatingTask(RandomGuessModel::class.java, option); +// 训练与评估模型并获取评分指标 +var ratingMeasures = ratingTask.execute(); +_data["mae"] = ratingMeasures.getFloat(MAEEvaluator::class.java); +_data["mse"] = ratingMeasures.getFloat(MSEEvaluator::class.java); + +_data; +``` + +* 使用JStarCraft框架从Model.kt文件加载并执行Kotlin脚本 + +```java +// 获取Kotlin脚本 +File file = new File(ScriptTestCase.class.getResource("Model.kt").toURI()); +String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + +// 设置Kotlin脚本使用到的Java类 +ScriptContext context = new ScriptContext(); +context.useClasses(Properties.class, 
+ +#### 6.3 Using the JStarCraft-RNS engine with JS scripts + +* [Complete example](https://github.com/HongZhaoHua/jstarcraft-rns/tree/master/src/test/java/com/jstarcraft/rns/script) + +* Write a JS script that trains and evaluates a model, saved as Model.js + +```js +// Build the configuration +var keyValues = new Properties(); +keyValues.load(loader.getResourceAsStream("data.properties")); +keyValues.load(loader.getResourceAsStream("recommend/benchmark/randomguess-test.properties")); +var configurator = new Configurator([keyValues]); + +// This object is returned to the Java program +var _data = {}; + +// Build the ranking task +task = new RankingTask(RandomGuessModel.class, configurator); +// Train and evaluate the model, then collect the ranking metrics +measures = task.execute(); +_data['precision'] = measures.get(PrecisionEvaluator.class); +_data['recall'] = measures.get(RecallEvaluator.class); + +// Build the rating task +task = new RatingTask(RandomGuessModel.class, configurator); +// Train and evaluate the model, then collect the rating metrics +measures = task.execute(); +_data['mae'] = measures.get(MAEEvaluator.class); +_data['mse'] = measures.get(MSEEvaluator.class); + +_data; +``` + +* Load and execute the JS script from Model.js with the JStarCraft framework + +```java +// Read the JS script +File file = new File(ScriptTestCase.class.getResource("Model.js").toURI()); +String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + +// Register the Java classes used by the JS script +ScriptContext context = new ScriptContext(); +context.useClasses(Properties.class, Assert.class); +context.useClass("Configurator", MapConfigurator.class); +context.useClasses("com.jstarcraft.ai.evaluate"); +context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// Register the Java variables used by the JS script +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// Execute the JS script +ScriptExpression expression = new JsExpression(context, scope, script); +Map<String, Float> data = expression.doWith(Map.class); +``` + +#### 6.4 Using the JStarCraft-RNS engine with Kotlin scripts + +* [Complete example](https://github.com/HongZhaoHua/jstarcraft-rns/tree/master/src/test/java/com/jstarcraft/rns/script) + +* Write a Kotlin script that trains and evaluates a model, saved as Model.kt + +```kotlin +// Build the configuration +var keyValues = Properties(); +var loader = bindings["loader"] as ClassLoader; +keyValues.load(loader.getResourceAsStream("data.properties")); +keyValues.load(loader.getResourceAsStream("model/benchmark/randomguess-test.properties")); +var option = Option(keyValues); + +// This object is returned to the Java program +var _data = mutableMapOf<String, Float>(); + +// Build the ranking task +var rankingTask = RankingTask(RandomGuessModel::class.java, option); +// Train and evaluate the model, then collect the ranking metrics +val rankingMeasures = rankingTask.execute(); +_data["precision"] = rankingMeasures.getFloat(PrecisionEvaluator::class.java); +_data["recall"] = rankingMeasures.getFloat(RecallEvaluator::class.java); + +// Build the rating task +var ratingTask = RatingTask(RandomGuessModel::class.java, option); +// Train and evaluate the model, then collect the rating metrics +var ratingMeasures = ratingTask.execute(); +_data["mae"] = ratingMeasures.getFloat(MAEEvaluator::class.java); +_data["mse"] = ratingMeasures.getFloat(MSEEvaluator::class.java); + +_data; +``` + +* Load and execute the Kotlin script from Model.kt with the JStarCraft framework + +```java +// Read the Kotlin script +File file = new File(ScriptTestCase.class.getResource("Model.kt").toURI()); +String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + +// Register the Java classes used by the Kotlin script +ScriptContext context = new ScriptContext(); +context.useClasses(Properties.class, Assert.class); +context.useClass("Option", MapOption.class); +context.useClasses("com.jstarcraft.ai.evaluate"); +context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// Register the Java variables used by the Kotlin script +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// Execute the Kotlin script +ScriptExpression expression = new KotlinExpression(context, scope, script); +Map<String, Float> data = expression.doWith(Map.class); +``` + +#### 6.5 Using the JStarCraft-RNS engine with Lua scripts + +* [Complete example](https://github.com/HongZhaoHua/jstarcraft-rns/tree/master/src/test/java/com/jstarcraft/rns/script) + +* Write a Lua script that trains and evaluates a model, saved as Model.lua + +```lua +-- Build the configuration +local keyValues = Properties.new(); +keyValues:load(loader:getResourceAsStream("data.properties")); + +keyValues:load(loader:getResourceAsStream("recommend/benchmark/randomguess-test.properties")); +local configurator = Configurator.new({ keyValues }); + +-- This object is returned to the Java program +local _data = {}; + +-- Build the ranking task +task = RankingTask.new(RandomGuessModel, configurator); +-- Train and evaluate the model, then collect the ranking metrics +measures = task:execute(); +_data["precision"] = measures:get(PrecisionEvaluator); +_data["recall"] = measures:get(RecallEvaluator); + +-- Build the rating task +task = RatingTask.new(RandomGuessModel, configurator); +-- Train and evaluate the model, then collect the rating metrics +measures = task:execute(); +_data["mae"] = measures:get(MAEEvaluator); +_data["mse"] = measures:get(MSEEvaluator); + +return _data; +``` + +* Load and execute the Lua script from Model.lua with the JStarCraft framework + +```java +// Read the Lua script +File file = new File(ScriptTestCase.class.getResource("Model.lua").toURI()); +String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + +// Register the Java classes used by the Lua script +ScriptContext context = new ScriptContext(); +context.useClasses(Properties.class, Assert.class); +context.useClass("Configurator", MapConfigurator.class); +context.useClasses("com.jstarcraft.ai.evaluate"); +context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// Register the Java variables used by the Lua script +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// Execute the Lua script +ScriptExpression expression = new LuaExpression(context, scope, script); +LuaTable data = expression.doWith(LuaTable.class); +``` + +#### 6.6 Using the JStarCraft-RNS engine with Python scripts + +* [Complete example](https://github.com/HongZhaoHua/jstarcraft-rns/tree/master/src/test/java/com/jstarcraft/rns/script) + +* Write a Python script that trains and evaluates a model, saved as Model.py + +```python +# Build the configuration +keyValues = Properties() +keyValues.load(loader.getResourceAsStream("data.properties")) +keyValues.load(loader.getResourceAsStream("recommend/benchmark/randomguess-test.properties")) +configurator = Configurator([keyValues]) + +# This object is returned to the Java program +_data = {} + +# Build the ranking task +task = RankingTask(RandomGuessModel, configurator) +# Train and evaluate the model, then collect the ranking metrics
measures = task.execute() +_data['precision'] = measures.get(PrecisionEvaluator) +_data['recall'] = measures.get(RecallEvaluator) + +# Build the rating task +task = RatingTask(RandomGuessModel, configurator) +# Train and evaluate the model, then collect the rating metrics +measures = task.execute() +_data['mae'] = measures.get(MAEEvaluator) +_data['mse'] = measures.get(MSEEvaluator) +``` + +* Load and execute the Python script from Model.py with the JStarCraft framework + +```java +// Set the Python environment variable +System.setProperty("python.console.encoding", StringUtility.CHARSET.name()); + +// Read the Python script +File file = new File(PythonTestCase.class.getResource("Model.py").toURI()); +String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + +// Register the Java classes used by the Python script +ScriptContext context = new ScriptContext(); +context.useClasses(Properties.class, Assert.class); +context.useClass("Configurator", MapConfigurator.class); +context.useClasses("com.jstarcraft.ai.evaluate"); +context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// Register the Java variables used by the Python script +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// Execute the Python script +ScriptExpression expression = new PythonExpression(context, scope, script); +Map<String, Float> data = expression.doWith(Map.class); +```
+context.useClass("Configurator", MapConfigurator.class); +context.useClasses("com.jstarcraft.ai.evaluate"); +context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// 设置Python脚本使用到的Java变量 +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// 执行Python脚本 +ScriptExpression expression = new PythonExpression(context, scope, script); +Map data = expression.doWith(Map.class); +``` + +#### 6.7JStarCraft-Ruby + +* [完整示例](https://github.com/HongZhaoHua/jstarcraft-rns/tree/master/src/test/java/com/jstarcraft/rns/script) + +* 编写Ruby脚本训练与评估模型并保存到Model.rb文件 + +```ruby +# 构建配置 +keyValues = Properties.new() +keyValues.load($loader.getResourceAsStream("data.properties")) +keyValues.load($loader.getResourceAsStream("model/benchmark/randomguess-test.properties")) +configurator = Configurator.new(keyValues) + +# 此对象会返回给Java程序 +_data = Hash.new() + +# 构建排序任务 +task = RankingTask.new(RandomGuessModel.java_class, configurator) +# 训练与评估模型并获取排序指标 +measures = task.execute() +_data['precision'] = measures.get(PrecisionEvaluator.java_class) +_data['recall'] = measures.get(RecallEvaluator.java_class) + +# 构建评分任务 +task = RatingTask.new(RandomGuessModel.java_class, configurator) +# 训练与评估模型并获取评分指标 +measures = task.execute() +_data['mae'] = measures.get(MAEEvaluator.java_class) +_data['mse'] = measures.get(MSEEvaluator.java_class) + +_data; +``` + +* 使用JStarCraft框架从Model.rb文件加载并执行Ruby脚本 + +```java +// 获取Ruby脚本 +File file = new File(ScriptTestCase.class.getResource("Model.rb").toURI()); +String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + +// 设置Ruby脚本使用到的Java类 +ScriptContext context = new ScriptContext(); +context.useClasses(Properties.class, Assert.class); +context.useClass("Configurator", MapConfigurator.class); +context.useClasses("com.jstarcraft.ai.evaluate"); +context.useClasses("com.jstarcraft.rns.task"); +context.useClasses("com.jstarcraft.rns.model.benchmark"); +// 设置Ruby脚本使用到的Java变量 +ScriptScope scope = new ScriptScope(); +scope.createAttribute("loader", loader); + +// 执行Ruby脚本 +ScriptExpression expression = new RubyExpression(context, scope, script); +Map data = expression.doWith(Map.class); +Assert.assertEquals(0.005825241096317768D, data.get("precision"), 0D); +Assert.assertEquals(0.011579763144254684D, data.get("recall"), 0D); +Assert.assertEquals(1.270874261856079D, data.get("mae"), 0D); +Assert.assertEquals(2.425075054168701D, data.get("mse"), 0D); +``` + +**** + +## 7.对比 + +#### 7.1排序模型对比 + +* 基准模型 + +| 名称 | 数据集 | 训练 (毫秒) | 预测 (毫秒) | AUC | MAP | MRR | NDCG | Novelty | Precision | Recall | +| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | +| MostPopular | filmtrust | 43 | 273 | 0.92080 | 0.41246 | 0.57196 | 0.51583 | 11.79295 | 0.33230 | 0.62385 | +| RandomGuess | filmtrust | 38 | 391 | 0.51922 | 0.00627 | 0.02170 | 0.01121 | 91.94900 | 0.00550 | 0.01262 | + +* 协同模型 + +| 名称 | 数据集 | 训练 (毫秒) | 预测 (毫秒) | AUC | MAP | MRR | NDCG | Novelty | Precision | Recall | +| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | +| AoBPR | filmtrust | 12448 | 253 | 0.89324 | 0.38967 | 0.53990 | 0.48338 | 21.13004 | 0.32295 | 0.56864 | +| AspectRanking | filmtrust | 177 | 58 | 0.85130 | 0.15498 | 0.42480 | 0.26012 | 37.36273 | 0.13302 | 0.31292 | +| BHFreeRanking | filmtrust | 5720 | 4257 | 0.92080 | 0.41316 | 0.57231 | 0.51662 | 11.79567 | 0.33276 | 0.62500 | +| BPR | filmtrust | 4228 | 137 | 0.89390 | 0.39886 | 0.54790 | 
+ +#### 7.2 Rating model comparison + +* Benchmark models + +| Name | Dataset | Training (ms) | Prediction (ms) | MAE | MPE | MSE | +| :----: | :----: | :----: | :----: | :----: | :----: | :----: | +| ConstantGuess | filmtrust | 137 | 45 | 1.05608 | 1.00000 | 1.42309 | +| GlobalAverage | filmtrust | 60 | 13 | 0.71977 | 0.77908 | 0.85199 | +| ItemAverage | filmtrust | 59 | 12 | 0.72968 | 0.97242 | 0.86413 | +| ItemCluster | filmtrust | 471 | 41 | 0.71976 | 0.77908 | 0.85198 | +| RandomGuess | filmtrust | 38 | 8 | 1.28622 | 0.99597 | 2.47927 | +| UserAverage | filmtrust | 35 | 9 | 0.64618 | 0.97242 | 0.70172 | +| UserCluster | filmtrust | 326 | 45 | 0.71977 | 0.77908 | 0.85199 | + +* Collaborative models + +| Name | Dataset | Training (ms) | Prediction (ms) | MAE | MPE | MSE | +| :----: | :----: | :----: | :----: | :----: | :----: | :----: | +| AspectRating | filmtrust | 220 | 5 | 0.65754 | 0.97918 | 0.71809 | +| ASVDPlusPlus | filmtrust | 5631 | 8 | 0.71975 | 0.77921 | 0.85196 | +| BiasedMF | filmtrust | 92 | 6 | 0.63157 | 0.98387 | 0.66220 | +| BHFreeRating | filmtrust | 6667 | 76 | 0.71974 | 0.77908 | 0.85198 | +| BPMF | filmtrust | 25942 | 52 | 0.66504 | 0.98465 | 0.70210 | +| BUCMRating | filmtrust | 1843 | 30 | 0.64834 | 0.99102 | 0.67992 | +| CCD | product | 15715 | 9 | 0.96670 | 0.93947 | 1.62145 | +| FFM | filmtrust | 5422 | 6 | 0.63446 | 0.98413 | 0.66682 | +| FMALS | filmtrust | 1854 | 5 | 0.64788 | 0.96032 | 0.73636 | +| FMSGD | filmtrust | 3496 | 10 | 0.63452 | 0.98426 | 0.66710 | +| GPLSA | filmtrust | 2567 | 7 | 0.67311 | 0.98972 | 0.79883 | +| IRRG | filmtrust | 40284 | 6 | 0.64766 | 0.98777 | 0.73700 | +| ItemKNNRating | filmtrust | 2052 | 27 | 0.62341 | 0.95394 | 0.67312 | +| LDCC | filmtrust | 8650 | 84 | 0.66383 | 0.99284 | 0.70666 | +| LLORMA | filmtrust | 16618 | 82 | 0.64930 | 0.96591 | 0.76067 | +| MFALS | filmtrust | 2944 | 5 | 0.82939 | 0.94549 | 1.30547 | +| NMF | filmtrust | 1198 | 8 | 0.67661 | 0.96604 | 0.83493 | +| PMF | filmtrust | 215 | 7 | 0.72959 | 0.98165 | 0.99948 | +| RBM | filmtrust | 19551 | 270 | 0.74484 | 0.98504 | 0.88968 | +| RFRec | filmtrust | 16330 | 54 | 0.64008 | 0.97112 | 0.69390 | +| SVDPlusPlus | filmtrust | 452 | 26 | 0.65248 | 0.99141 | 0.68289 | +| URP | filmtrust | 1514 | 25 | 0.64207 | 0.99128 | 0.67122 | +| UserKNNRating | filmtrust | 1121 | 135 | 0.63933 | 0.94640 | 0.69280 | +| RSTE | filmtrust | 4052 | 10 | 0.64303 | 0.99206 | 0.67777 | +| SocialMF | filmtrust | 918 | 13 | 0.64668 | 0.98881 | 0.68228 | +| SoRec | filmtrust | 1048 | 10 | 0.64305 | 0.99232 | 0.67776 | +| SoReg | filmtrust | 635 | 8 | 0.65943 | 0.96734 | 0.72760 | +| TimeSVD | filmtrust | 11545 | 36 | 0.68954 | 0.93326 | 0.87783 | +| TrustMF | filmtrust | 2038 | 7 | 0.63787 | 0.98985 | 0.69017 | +| TrustSVD | filmtrust | 12465 | 22 | 0.61984 | 0.98933 | 0.63875 | +| PersonalityDiagnosis | filmtrust | 45 | 642 | 0.72964 | 0.76620 | 1.03071 | +| SlopeOne | filmtrust | 135 | 28 | 0.63788 | 0.96175 | 0.71057 | + +* Content models + +| Name | Dataset | Training (ms) | Prediction (ms) | MAE | MPE | MSE | +| :----: | :----: | :----: | :----: | :----: | :----: | :----: | +| EFMRating | dc_dense | 659 | 8 | 0.61546 | 0.85364 | 0.78279 | +| HFT | musical_instruments | 162753 | 13 | 0.64272 | 0.94886 | 0.81393 | +| TopicMFAT | musical_instruments | 6907 | 7 | 0.61896 | 0.98734 | 0.72545 | +| TopicMFMT | musical_instruments | 6323 | 7 | 0.61896 | 0.98734 | 0.72545 |
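For reference, MAE and MSE above follow their standard definitions over the test set $T$ of (user, item) pairs, with $r_{ui}$ the observed and $\hat{r}_{ui}$ the predicted score (MPE is kept as named by the engine):

$$\mathrm{MAE} = \frac{1}{|T|}\sum_{(u,i) \in T}\left|r_{ui} - \hat{r}_{ui}\right|, \qquad \mathrm{MSE} = \frac{1}{|T|}\sum_{(u,i) \in T}\left(r_{ui} - \hat{r}_{ui}\right)^2$$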
+ +## 8. References + +#### 8.1 Model descriptions + +* Benchmark models + +| Name | Problem | Description/Paper | +| :----: | :----: | :----: | +| RandomGuess | Ranking Rating | Random guess | +| MostPopular | Ranking | Most popular | +| ConstantGuess | Rating | Constant guess | +| GlobalAverage | Rating | Global average | +| ItemAverage | Rating | Item average | +| ItemCluster | Rating | Item clustering | +| UserAverage | Rating | User average | +| UserCluster | Rating | User clustering | + +* Collaborative models + +| Name | Problem | Description/Paper | +| :----: | :----: | :----: | +| AspectModel | Ranking Rating | Latent class models for collaborative filtering | +| BHFree | Ranking Rating | Balancing Prediction and Recommendation Accuracy: Hierarchical Latent Factors for Preference Data | +| BUCM | Ranking Rating | Modeling Item Selection and Relevance for Accurate Recommendations | +| ItemKNN | Ranking Rating | Item-based collaborative filtering | +| UserKNN | Ranking Rating | User-based collaborative filtering | +| AoBPR | Ranking | Improving pairwise learning for item recommendation from implicit feedback | +| BPR | Ranking | BPR: Bayesian Personalized Ranking from Implicit Feedback | +| CLiMF | Ranking | CLiMF: learning to maximize reciprocal rank with collaborative less-is-more filtering | +| EALS | Ranking | Collaborative filtering for implicit feedback dataset | +| FISM | Ranking | FISM: Factored Item Similarity Models for Top-N Recommender Systems | +| GBPR | Ranking | GBPR: Group Preference Based Bayesian Personalized Ranking for One-Class Collaborative Filtering | +| HMMForCF | Ranking | A hidden Markov model for collaborative filtering | +| ItemBigram | Ranking | Topic Modeling: Beyond Bag-of-Words | +| LambdaFM | Ranking | LambdaFM: Learning Optimal Ranking with Factorization Machines Using Lambda Surrogates | +| LDA | Ranking | Latent Dirichlet Allocation for implicit feedback | +| ListwiseMF | Ranking | List-wise learning to rank with matrix factorization for collaborative filtering | +| PLSA | Ranking | Latent semantic models for collaborative filtering | +| RankALS | Ranking | Alternating Least Squares for Personalized Ranking | +| RankSGD | Ranking | Collaborative Filtering Ensemble for Ranking | +| SLIM | Ranking | SLIM: Sparse Linear Methods for Top-N Recommender Systems | +| WBPR | Ranking | Bayesian Personalized Ranking for Non-Uniformly Sampled Items | +| WRMF | Ranking | Collaborative filtering for implicit feedback datasets | +| Rank-GeoFM | Ranking | Rank-GeoFM: A ranking based geographical factorization method for point of interest recommendation | +| SBPR | Ranking | Leveraging Social Connections to Improve Personalized Ranking for Collaborative Filtering | +| AssociationRule | Ranking | A Recommendation Algorithm Using Multi-Level Association Rules | +| PRankD | Ranking | Personalised ranking with diversity | +| AsymmetricSVD++ | Rating | Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model | +| AutoRec | Rating | AutoRec: Autoencoders Meet Collaborative Filtering | +| BPMF | Rating | Bayesian Probabilistic Matrix Factorization using Markov Chain Monte Carlo | +| CCD | Rating | Large-Scale Parallel Collaborative Filtering for the Netflix Prize | +| FFM | Rating | Field Aware Factorization Machines for CTR Prediction | +| GPLSA | Rating | Collaborative Filtering via Gaussian Probabilistic Latent Semantic Analysis | +| IRRG | Rating | Exploiting Implicit Item Relationships for Recommender Systems | +| MFALS | Rating | Large-Scale Parallel Collaborative Filtering for the Netflix Prize | +| NMF | Rating | Algorithms for Non-negative Matrix Factorization | +| PMF | Rating | PMF: Probabilistic Matrix Factorization |
| RBM | Rating | Restricted Boltzmann Machines for Collaborative Filtering | +| RF-Rec | Rating | RF-Rec: Fast and Accurate Computation of Recommendations based on Rating Frequencies | +| SVD++ | Rating | Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model | +| URP | Rating | User Rating Profile: a LDA model for rating prediction | +| RSTE | Rating | Learning to Recommend with Social Trust Ensemble | +| SocialMF | Rating | A matrix factorization technique with trust propagation for recommendation in social networks | +| SoRec | Rating | SoRec: Social recommendation using probabilistic matrix factorization | +| SoReg | Rating | Recommender systems with social regularization | +| TimeSVD++ | Rating | Collaborative Filtering with Temporal Dynamics | +| TrustMF | Rating | Social Collaborative Filtering by Trust | +| TrustSVD | Rating | TrustSVD: Collaborative Filtering with Both the Explicit and Implicit Influence of User Trust and of Item Ratings | +| PersonalityDiagnosis | Rating | A brief introduction to Personality Diagnosis | +| SlopeOne | Rating | Slope One Predictors for Online Rating-Based Collaborative Filtering | + +* Content models + +| Name | Problem | Description/Paper | +| :----: | :----: | :----: | +| EFM | Ranking Rating | Explicit factor models for explainable recommendation based on phrase-level sentiment analysis | +| TF-IDF | Ranking | Term frequency-inverse document frequency | +| HFT | Rating | Hidden factors and hidden topics: understanding rating dimensions with review text | +| TopicMF | Rating | TopicMF: Simultaneously Exploiting Ratings and Reviews for Recommendation | + +#### 8.2 Datasets + +* [Amazon Dataset](http://jmcauley.ucsd.edu/data/amazon/) +* [Bibsonomy Dataset](https://www.kde.cs.uni-kassel.de/wp-content/uploads/bibsonomy/) +* [BookCrossing Dataset](https://grouplens.org/datasets/book-crossing/) +* [Ciao Dataset](https://www.cse.msu.edu/~tangjili/datasetcode/truststudy.htm) +* [Douban Dataset](http://smiles.xjtu.edu.cn/Download/Download_Douban.html) +* [Eachmovie Dataset](https://grouplens.org/datasets/eachmovie/) +* [Epinions Dataset](http://www.trustlet.org/epinions.html) +* [Foursquare Dataset](https://sites.google.com/site/yangdingqi/home/foursquare-dataset) +* [Goodbooks Dataset](http://fastml.com/goodbooks-10k-a-new-dataset-for-book-recommendations/) +* [Gowalla Dataset](http://snap.stanford.edu/data/loc-gowalla.html) +* [HetRec2011 Dataset](https://grouplens.org/datasets/hetrec-2011/) +* [Jester Dataset](https://grouplens.org/datasets/jester/) +* [Large Movie Review Dataset](http://ai.stanford.edu/~amaas/data/sentiment/) +* [MovieLens Dataset](https://grouplens.org/datasets/movielens/) +* [Newsgroups Dataset](http://qwone.com/~jason/20Newsgroups/) +* [Stanford Large Network Dataset](http://snap.stanford.edu/data/) +* [Serendipity 2018 Dataset](https://grouplens.org/datasets/serendipity-2018/) +* [Wikilens Dataset](https://grouplens.org/datasets/wikilens/) +* [Yelp Dataset](https://www.yelp.com/dataset) +* [Yongfeng Zhang Dataset](http://yongfeng.me/dataset/) +
-JStarCraft RNS是一个面向信息检索领域的轻量级引擎.遵循Apache 2.0协议.
\ No newline at end of file diff --git a/data.7z b/data.7z new file mode 100644 index 0000000..1ab6a8a Binary files /dev/null and b/data.7z differ diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..1147759 --- /dev/null +++ b/pom.xml @@ -0,0 +1,156 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <groupId>com.jstarcraft</groupId> + <artifactId>jstarcraft-rns</artifactId> + <version>1.0</version> + <packaging>jar</packaging> + + <properties> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + </properties> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>3.8.0</version> + <configuration> + <source>1.8</source> + <target>1.8</target> + <encoding>UTF-8</encoding> + </configuration> + </plugin> + </plugins> + </build> + + <dependencies> + <dependency> + <groupId>jdk.tools</groupId> + <artifactId>jdk.tools</artifactId> + <version>1.8</version> + <scope>system</scope> + <systemPath>${JAVA_HOME}/lib/tools.jar</systemPath> + </dependency> + + <dependency> + <groupId>com.jstarcraft</groupId> + <artifactId>jstarcraft-core-script</artifactId> + <version>1.0</version> + </dependency> + + <dependency> + <groupId>com.jstarcraft</groupId> + <artifactId>jstarcraft-ai-model</artifactId> + <version>1.0</version> + </dependency> + + <dependency> + <groupId>ch.obermuhlner</groupId> + <artifactId>big-math</artifactId> + <version>2.1.0</version> + </dependency> + + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <version>4.13.1</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.springframework</groupId> + <artifactId>spring-test</artifactId> + <version>5.1.6.RELEASE</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>net.sourceforge.jdistlib</groupId> + <artifactId>jdistlib</artifactId> + <version>0.4.5</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.nd4j</groupId> + <artifactId>nd4j-native-platform</artifactId> + <version>1.0.0-beta3</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.logging.log4j</groupId> + <artifactId>log4j-slf4j-impl</artifactId> + <version>2.11.2</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.logging.log4j</groupId> + <artifactId>log4j-jcl</artifactId> + <version>2.11.2</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache-extras.beanshell</groupId> + <artifactId>bsh</artifactId> + <version>2.0b6</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.codehaus.groovy</groupId> + <artifactId>groovy-all</artifactId> + <version>2.4.16</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.jetbrains.kotlin</groupId> + <artifactId>kotlin-scripting-jsr223</artifactId> + <version>1.4.21</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.luaj</groupId> + <artifactId>luaj-jse</artifactId> + <version>3.0.1</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.mvel</groupId> + <artifactId>mvel2</artifactId> + <version>2.4.4.Final</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.python</groupId> + <artifactId>jython-standalone</artifactId> + <version>2.7.1</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.jruby</groupId> + <artifactId>jruby-complete</artifactId> + <version>9.2.11.1</version> + <scope>test</scope> + </dependency> + </dependencies> +</project> diff --git a/src/main/java/com/jstarcraft/rns/data/processor/AllFeatureDataSorter.java new file mode 100644 index 0000000..b69b0b4 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/data/processor/AllFeatureDataSorter.java @@ -0,0 +1,27 @@ +package com.jstarcraft.rns.data.processor; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.processor.DataSorter; + +public class AllFeatureDataSorter implements DataSorter { + + @Override + public int sort(DataInstance left, DataInstance right) { + for (int dimension = 0, order = left.getQualityOrder(); dimension < order; dimension++) { + int leftValue = left.getQualityFeature(dimension); + int rightValue = right.getQualityFeature(dimension); + if (leftValue != rightValue) { + return leftValue < rightValue ? -1 : 1; + } + } + for (int dimension = 0, order = right.getQuantityOrder(); dimension < order; dimension++) { + float leftValue = left.getQuantityFeature(dimension); + float rightValue = right.getQuantityFeature(dimension); + if (leftValue != rightValue) { + return leftValue < rightValue ? -1 : 1; + } + } + return 0; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/data/processor/QualityFeatureDataSorter.java new file mode 100644 index 0000000..41946e5 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/data/processor/QualityFeatureDataSorter.java @@ -0,0 +1,24 @@ +package com.jstarcraft.rns.data.processor; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.processor.DataSorter; + +public class QualityFeatureDataSorter implements DataSorter { + + private int dimension; + + public QualityFeatureDataSorter(int dimension) { + this.dimension = dimension; + } + + @Override + public int sort(DataInstance left, DataInstance right) { + int leftValue = left.getQualityFeature(dimension); + int rightValue = right.getQualityFeature(dimension); + if (leftValue != rightValue) { + return leftValue < rightValue ?
-1 : 1; + } + return 0; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/data/processor/QualityFeatureDataSplitter.java b/src/main/java/com/jstarcraft/rns/data/processor/QualityFeatureDataSplitter.java new file mode 100644 index 0000000..7b1cff0 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/data/processor/QualityFeatureDataSplitter.java @@ -0,0 +1,19 @@ +package com.jstarcraft.rns.data.processor; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.processor.DataSplitter; + +public class QualityFeatureDataSplitter implements DataSplitter { + + private int dimension; + + public QualityFeatureDataSplitter(int dimension) { + this.dimension = dimension; + } + + @Override + public int split(DataInstance instance) { + return instance.getQualityFeature(dimension); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/data/processor/QuantityFeatureDataSorter.java b/src/main/java/com/jstarcraft/rns/data/processor/QuantityFeatureDataSorter.java new file mode 100644 index 0000000..1dfc055 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/data/processor/QuantityFeatureDataSorter.java @@ -0,0 +1,24 @@ +package com.jstarcraft.rns.data.processor; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.processor.DataSorter; + +public class QuantityFeatureDataSorter implements DataSorter { + + private int dimension; + + public QuantityFeatureDataSorter(int dimension) { + this.dimension = dimension; + } + + @Override + public int sort(DataInstance left, DataInstance right) { + float leftValue = left.getQuantityFeature(dimension); + float rightValue = right.getQuantityFeature(dimension); + if (leftValue != rightValue) { + return leftValue < rightValue ? -1 : 1; + } + return 0; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/data/processor/RandomDataSorter.java b/src/main/java/com/jstarcraft/rns/data/processor/RandomDataSorter.java new file mode 100644 index 0000000..49a611d --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/data/processor/RandomDataSorter.java @@ -0,0 +1,35 @@ +package com.jstarcraft.rns.data.processor; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.IntegerArray; +import com.jstarcraft.ai.data.module.ReferenceModule; +import com.jstarcraft.ai.data.processor.DataSorter; +import com.jstarcraft.core.utility.RandomUtility; + +public class RandomDataSorter implements DataSorter { + + @Override + public int sort(DataInstance left, DataInstance right) { + throw new UnsupportedOperationException(); + } + + @Override + public ReferenceModule sort(DataModule module) { + int size = module.getSize(); + IntegerArray reference = new IntegerArray(size, size); + for (int index = 0; index < size; index++) { + reference.associateData(index); + } + int from = 0; + int to = size; + for (int index = from; index < to; index++) { + int random = RandomUtility.randomInteger(from, to); + int data = reference.getData(index); + reference.setData(index, reference.getData(random)); + reference.setData(random, data); + } + return new ReferenceModule(reference, module); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/AbstractModel.java b/src/main/java/com/jstarcraft/rns/model/AbstractModel.java new file mode 100644 index 0000000..74f715e --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/AbstractModel.java @@ -0,0 +1,100 @@ +package com.jstarcraft.rns.model; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import 
com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.processor.DataSorter; +import com.jstarcraft.ai.data.processor.DataSplitter; +import com.jstarcraft.ai.environment.EnvironmentContext; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.rns.data.processor.AllFeatureDataSorter; +import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter; + +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * Abstract recommender + * + * @author Birdy + * + */ +public abstract class AbstractModel implements Model { + + protected final Logger logger = LoggerFactory.getLogger(this.getClass()); + + protected String id; + + // parameters + /** user field, item field */ + protected String userField, itemField; + + /** user dimension, item dimension */ + protected int userDimension, itemDimension; + + /** number of users, number of items */ + protected int userSize, itemSize; + + /** minimum, maximum and mean score */ + protected float minimumScore, maximumScore, meanScore; + + /** number of actions (TODO this field may move to another class; to avoid duplicate actions, the element count of the matrix or tensor is typically used) */ + protected int actionSize; + + /** training matrix (TODO to be renamed actionMatrix or scoreMatrix) */ + protected SparseMatrix scoreMatrix; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + userField = configuration.getString("data.model.fields.user", "user"); + itemField = configuration.getString("data.model.fields.item", "item"); + + userDimension = model.getQualityInner(userField); + itemDimension = model.getQualityInner(itemField); + userSize = space.getQualityAttribute(userField).getSize(); + itemSize = space.getQualityAttribute(itemField).getSize(); + + DataSplitter splitter = new QualityFeatureDataSplitter(userDimension); + DataModule[] models = splitter.split(model, userSize); + DataSorter sorter = new AllFeatureDataSorter(); + for (int index = 0; index < userSize; index++) { + models[index] = sorter.sort(models[index]); + } + + HashMatrix dataTable = new HashMatrix(true, userSize, itemSize, new Long2FloatRBTreeMap()); + for (DataInstance instance : model) { + int rowIndex = instance.getQualityFeature(userDimension); + int columnIndex = instance.getQualityFeature(itemDimension); + dataTable.setValue(rowIndex, columnIndex, instance.getQuantityMark()); + } + scoreMatrix = SparseMatrix.valueOf(userSize, itemSize, dataTable); + actionSize = scoreMatrix.getElementSize(); + KeyValue<Float, Float> attribute = scoreMatrix.getBoundary(false); + minimumScore = attribute.getKey(); + maximumScore = attribute.getValue(); + meanScore = scoreMatrix.getSum(false); + meanScore /= actionSize; + } + + protected abstract void doPractice(); + + protected void constructEnvironment() { + } + + protected void destructEnvironment() { + } + + @Override + public final void practice() { + EnvironmentContext context = EnvironmentContext.getContext(); + context.doAlgorithmByEvery(this::constructEnvironment); + doPractice(); + context.doAlgorithmByEvery(this::destructEnvironment); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/EpocheModel.java new file mode 100644 index 0000000..a82ecbd --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/EpocheModel.java @@ -0,0 +1,104 @@ +package com.jstarcraft.rns.model; + +import com.jstarcraft.ai.data.DataModule; +import
com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.MathUtility; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.StringUtility; +import com.jstarcraft.rns.model.exception.ModelException; +import com.jstarcraft.rns.utility.LogisticUtility; + +/** + * Epoch-based recommender + * + * <pre> + * Related to machine learning + * </pre> + * + * @author Birdy + * + */ +public abstract class EpocheModel extends AbstractModel { + + /** number of epochs */ + protected int epocheSize; + + /** whether to use convergence as the early-stop criterion */ + protected boolean isConverged; + + /** for monitoring the loss */ + protected float totalError, currentError; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // parameters + epocheSize = configuration.getInteger("recommender.iterator.maximum", 100); + isConverged = configuration.getBoolean("recommender.recommender.earlystop", false); + } + + /** + * Whether the training has converged + * + * @param iteration + * @return + */ + protected boolean isConverged(int iteration) { + float deltaError = currentError - totalError; + // print out debug info + if (logger.isInfoEnabled()) { + String message = StringUtility.format("{} : epoch is {}, total is {}, delta is {}", getClass().getSimpleName(), iteration, totalError, deltaError); + logger.info(message); + } + if (Float.isNaN(totalError) || Float.isInfinite(totalError)) { + throw new ModelException("Loss = NaN or Infinity: the current settings do not fit the recommender! Change the settings and try again!"); + } + // check if converged + boolean converged = Math.abs(deltaError) < MathUtility.EPSILON; + return converged; + } + + /** + * Calculates the gradient magnitude for the given pairwise loss function type. + * + * @param lossType + * @param error + * @return + */ + protected final float calaculateGradientValue(int lossType, float error) { + final float constant = 1F; + float value = 0F; + switch (lossType) { + case 0:// Hinge loss + if (constant * error <= 1F) + value = constant; + break; + case 1:// Rennie loss + if (constant * error <= 0F) + value = -constant; + else if (constant * error <= 1F) + value = (1F - constant * error) * (-constant); + else + value = 0F; + value = -value; + break; + case 2:// logistic loss, BPR + value = LogisticUtility.getValue(-error); + break; + case 3:// Frank loss + value = (float) (Math.sqrt(LogisticUtility.getValue(error)) / (1F + Math.exp(error))); + break; + case 4:// Exponential loss + value = (float) Math.exp(-error); + break; + case 5:// quadratically smoothed + if (error <= 1F) + value = 0.5F * (1F - error); + break; + default: + break; + } + return value; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/FactorizationMachineModel.java new file mode 100644 index 0000000..f224760 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/FactorizationMachineModel.java @@ -0,0 +1,211 @@ +package com.jstarcraft.rns.model; + +import java.util.Map.Entry; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.algorithm.probability.QuantityProbability; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.ArrayVector; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.KeyValue; +import
com.jstarcraft.rns.model.exception.ModelException; + +/** + * Factorization Machine Recommender + * + * Rendle, Steffen, et al., Fast Context-aware Recommendations with + * Factorization Machines, SIGIR, 2011. + * + * @author Tang Jiaxi and Ma Chen + */ +// TODO 论文中需要支持组合特征(比如:历史评价过的电影),现在的代码并没有实现. +public abstract class FactorizationMachineModel extends EpocheModel { + + /** 是否自动调整学习率 */ + protected boolean isLearned; + + /** 衰减率 */ + protected float learnDecay; + + /** + * learn rate, maximum learning rate + */ + protected float learnRatio, learnLimit; + + protected DataModule marker; + + /** + * global bias + */ + protected float globalBias; + /** + * appender vector size: number of users + number of items + number of + * contextual conditions + */ + protected int featureSize; + /** + * number of factors + */ + protected int factorSize; + + /** + * weight vector + */ + protected DenseVector weightVector; // p + /** + * parameter matrix(featureFactors) + */ + protected DenseMatrix featureFactors; // p x k + /** + * parameter matrix(rateFactors) + */ + protected DenseMatrix actionFactors; // n x k + /** + * regularization term for weight and factors + */ + protected float biasRegularization, weightRegularization, factorRegularization; + + /** + * init mean + */ + protected float initMean; + + /** + * init standard deviation + */ + protected float initStd; + + protected QuantityProbability distribution; + + protected int[] dimensionSizes; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + isLearned = configuration.getBoolean("recommender.learnrate.bolddriver", false); + learnDecay = configuration.getFloat("recommender.learnrate.decay", 1.0f); + learnRatio = configuration.getFloat("recommender.iterator.learnrate", 0.01f); + learnLimit = configuration.getFloat("recommender.iterator.learnrate.maximum", 1000.0f); + + maximumScore = configuration.getFloat("recommender.recommender.maxrate", 12F); + minimumScore = configuration.getFloat("recommender.recommender.minrate", 0F); + + factorSize = configuration.getInteger("recommender.factor.number"); + + // init all weight with zero + globalBias = 0; + + // init factors with small value + // TODO 此处需要重构 + initMean = configuration.getFloat("recommender.init.mean", 0F); + initStd = configuration.getFloat("recommender.init.std", 0.1F); + + biasRegularization = configuration.getFloat("recommender.fm.regw0", 0.01F); + weightRegularization = configuration.getFloat("recommender.fm.regW", 0.01F); + factorRegularization = configuration.getFloat("recommender.fm.regF", 10F); + + // TODO 暂时不支持连续特征,考虑将连续特征离散化. 
 + this.marker = model; + dimensionSizes = new int[marker.getQualityOrder()]; + + // TODO Consider refactoring: initialize in AbstractRecommender + actionSize = marker.getSize(); + // initialize the parameters of FM + // TODO Needs refactoring: mapping between outer indexes and inner indexes + for (int orderIndex = 0, orderSize = marker.getQualityOrder() + marker.getQuantityOrder(); orderIndex < orderSize; orderIndex++) { + Entry<Integer, KeyValue<String, Boolean>> term = marker.getOuterKeyValue(orderIndex); + if (term.getValue().getValue()) { + // handle discrete dimensions + dimensionSizes[marker.getQualityInner(term.getValue().getKey())] = space.getQualityAttribute(term.getValue().getKey()).getSize(); + featureSize += dimensionSizes[marker.getQualityInner(term.getValue().getKey())]; + } else { + // handle continuous dimensions + } + } + weightVector = DenseVector.valueOf(featureSize); + distribution = new QuantityProbability(JDKRandomGenerator.class, 0, NormalDistribution.class, initMean, initStd); + featureFactors = DenseMatrix.valueOf(featureSize, factorSize); + featureFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + } + + /** + * Gets the feature vector + * + *
                      
                      +     * Implemented as One-Hot Encoding
                      +     * For the underlying idea and usage see: http://blog.csdn.net/pipisorry/article/details/61193868
                      
+     * 
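                      +     *
                      +     * A worked example with assumed sizes: if dimensionSizes = {3, 5} and an
                      +     * instance holds the discrete features (user 1, item 2), the encoded indexes
                      +     * are 1 and 3 + 2 = 5, so the returned vector of length 8 holds ones exactly
                      +     * at positions 1 and 5.
                      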
+ * + * @param featureIndexes + * @return + */ + protected MathVector getFeatureVector(DataInstance instance) { + int orderSize = instance.getQualityOrder(); + int[] keys = new int[orderSize]; + int cursor = 0; + for (int orderIndex = 0; orderIndex < orderSize; orderIndex++) { + keys[orderIndex] += cursor + instance.getQualityFeature(orderIndex); + cursor += dimensionSizes[orderIndex]; + } + ArrayVector vector = new ArrayVector(featureSize, keys); + vector.setValues(1F); + return vector; + } + + /** + * Predict the rating given a sparse appender vector. + * + * @param userIndex user Id + * @param itemIndex item Id + * @param featureVector the given vector to predict. + * + * @return predicted rating + * @throws ModelException if error occurs + */ + protected float predict(DefaultScalar scalar, MathVector featureVector) { + float value = 0; + // global bias + value += globalBias; + // 1-way interaction + value += scalar.dotProduct(weightVector, featureVector).getValue(); + + // 2-way interaction + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float scoreSum = 0F; + float predictSum = 0F; + for (VectorScalar vectorTerm : featureVector) { + float featureValue = vectorTerm.getValue(); + int featureIndex = vectorTerm.getIndex(); + float predictValue = featureFactors.getValue(featureIndex, factorIndex); + + scoreSum += predictValue * featureValue; + predictSum += predictValue * predictValue * featureValue * featureValue; + } + value += (scoreSum * scoreSum - predictSum) / 2F; + } + + return value; + } + + @Override + public void predict(DataInstance instance) { + DefaultScalar scalar = DefaultScalar.getInstance(); + // TODO 暂时不支持连续特征,考虑将连续特征离散化. + MathVector featureVector = getFeatureVector(instance); + instance.setQuantityMark(predict(scalar, featureVector)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/MatrixFactorizationModel.java b/src/main/java/com/jstarcraft/rns/model/MatrixFactorizationModel.java new file mode 100644 index 0000000..cef8f90 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/MatrixFactorizationModel.java @@ -0,0 +1,167 @@ +package com.jstarcraft.rns.model; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.algorithm.probability.QuantityProbability; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.core.common.option.Option; + +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; + +/** + * 矩阵分解推荐器 + * + * @author Birdy + * + */ +public abstract class MatrixFactorizationModel extends EpocheModel { + + /** 是否自动调整学习率 */ + protected boolean isLearned; + + /** 衰减率 */ + protected float learnDecay; + + /** + * learn rate, maximum learning rate + */ + protected float learnRatio, learnLimit; + + /** + * user latent factors + */ + protected DenseMatrix userFactors; + + /** + * item latent factors + */ + protected DenseMatrix itemFactors; + + /** + * the number of latent factors; + 
*/ + protected int factorSize; + + /** + * user regularization + */ + protected float userRegularization; + + /** + * item regularization + */ + protected float itemRegularization; + + /** + * init mean + */ + protected float initMean; + + /** + * init standard deviation + */ + protected float initStd; + + protected QuantityProbability distribution; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + userRegularization = configuration.getFloat("recommender.user.regularization", 0.01f); + itemRegularization = configuration.getFloat("recommender.item.regularization", 0.01f); + + factorSize = configuration.getInteger("recommender.factor.number", 10); + + isLearned = configuration.getBoolean("recommender.learnrate.bolddriver", false); + learnDecay = configuration.getFloat("recommender.learnrate.decay", 1.0f); + learnRatio = configuration.getFloat("recommender.iterator.learnrate", 0.01f); + learnLimit = configuration.getFloat("recommender.iterator.learnrate.maximum", 1000.0f); + + // TODO 此处需要重构 + initMean = configuration.getFloat("recommender.init.mean", 0F); + initStd = configuration.getFloat("recommender.init.std", 0.1F); + + distribution = new QuantityProbability(JDKRandomGenerator.class, 0, NormalDistribution.class, initMean, initStd); + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + } + + protected float predict(int userIndex, int itemIndex) { + DenseVector userVector = userFactors.getRowVector(userIndex); + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + DefaultScalar scalar = DefaultScalar.getInstance(); + return scalar.dotProduct(userVector, itemVector).getValue(); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + + /** + * Update current learning rate after each epoch
                      + *
                        + *
                      1. bold driver: Gemulla et al., Large-scale matrix factorization with + * distributed stochastic gradient descent, KDD 2011.
                      2. + *
                      3. constant decay: Niu et al., Hogwild!: A lock-free approach to + * parallelizing stochastic gradient descent, NIPS 2011.
                      4. + *
                      5. Leon Bottou, Stochastic Gradient Descent Tricks
                      6. + *
                      7. more ways to adapt the learning rate are described at: + * http://www.willamette.edu/~gorr/classes/cs449/momrate.html
                      8. + *
                      
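 + * + * A numeric sketch of the update below (the starting rate is assumed for + * illustration): with the bold driver enabled and learnRatio = 0.02, an epoch + * whose error shrank (|currentError| > |totalError|) yields 0.02 * 1.05F = + * 0.021, otherwise 0.02 * 0.5F = 0.01; with the bold driver disabled and 0 < + * learnDecay < 1, the rate is simply multiplied by learnDecay. Either way the + * result is capped at learnLimit. 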
+ * + * @param iteration the current iteration + */ + protected void isLearned(int iteration) { + if (learnRatio < 0F) { + return; + } + if (isLearned && iteration > 1) { + learnRatio = Math.abs(currentError) > Math.abs(totalError) ? learnRatio * 1.05F : learnRatio * 0.5F; + } else if (learnDecay > 0 && learnDecay < 1) { + learnRatio *= learnDecay; + } + // limit to max-learn-rate after update + if (learnLimit > 0 && learnRatio > learnLimit) { + learnRatio = learnLimit; + } + } + + @Deprecated + // TODO 此方法准备取消,利用向量的有序性代替 + protected List getUserItemSet(SparseMatrix sparseMatrix) { + List userItemSet = new ArrayList<>(userSize); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = sparseMatrix.getRowVector(userIndex); + IntSet indexes = new IntOpenHashSet(); + for (int position = 0, size = userVector.getElementSize(); position < size; position++) { + indexes.add(userVector.getIndex(position)); + } + userItemSet.add(indexes); + } + return userItemSet; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/Model.java b/src/main/java/com/jstarcraft/rns/model/Model.java new file mode 100644 index 0000000..358d9a1 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/Model.java @@ -0,0 +1,52 @@ +package com.jstarcraft.rns.model; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.core.common.option.Option; + +/** + * 推荐器 + * + *
 + * Note the distinct responsibilities of each phase: + * the prepare phase focuses on data, converting it as the algorithm requires; + * the practice phase focuses on parameters, deriving the model from them; + * the predict phase focuses on the model, predicting scores with it; 
+ * 
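 + * + * A hedged usage sketch of the three phases (the configurator, module, space + * and instance objects are assumed to be supplied by the caller): + * + *
                       + * Model model = ...; // any concrete recommender
                       + * model.prepare(configurator, module, space); // data phase
                       + * model.practice();                           // parameter phase
                       + * model.predict(instance);                    // model phase, writes the score via setQuantityMark
                       + * 
                      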
+ * + * @author Birdy + * + */ +public interface Model { + + /** + * 准备 + * + * @param configurator + */ + void prepare(Option configurator, DataModule module, DataSpace space); + // void prepare(Configuration configuration, SparseTensor trainTensor, + // SparseTensor testTensor, DataSpace storage); + + /** + * 训练 + * + * @param trainTensor + * @param testTensor + * @param contextModels + */ + void practice(); + + /** + * 预测 + * + * @param userIndex + * @param itemIndex + * @param featureIndexes + * @return + */ + void predict(DataInstance instance); + // double predict(int userIndex, int itemIndex, int... featureIndexes); + +} diff --git a/src/main/java/com/jstarcraft/rns/model/NeuralNetworkModel.java b/src/main/java/com/jstarcraft/rns/model/NeuralNetworkModel.java new file mode 100644 index 0000000..a02d668 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/NeuralNetworkModel.java @@ -0,0 +1,100 @@ +package com.jstarcraft.rns.model; + +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.nd4j.linalg.api.ndarray.INDArray; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.core.common.option.Option; + +/** + * 神经网络推荐器 + * + * @author Birdy + * + */ +public abstract class NeuralNetworkModel extends EpocheModel { + + /** + * the dimension of input units + */ + protected int inputDimension; + + /** + * the dimension of hidden units + */ + protected int hiddenDimension; + + /** + * the activation function of the hidden layer in the neural network + */ + protected String hiddenActivation; + + /** + * the activation function of the output layer in the neural network + */ + protected String outputActivation; + + /** + * the learning rate of the optimization algorithm + */ + protected float learnRatio; + + /** + * the momentum of the optimization algorithm + */ + protected float momentum; + + /** + * the regularization coefficient of the weights in the neural network + */ + protected float weightRegularization; + + /** + * the data structure that stores the training data + */ + protected INDArray inputData; + + /** + * the data structure that stores the predicted data + */ + protected INDArray outputData; + + protected MultiLayerNetwork network; + + protected abstract int getInputDimension(); + + protected abstract MultiLayerConfiguration getNetworkConfiguration(); + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + inputDimension = getInputDimension(); + hiddenDimension = configuration.getInteger("recommender.hidden.dimension"); + hiddenActivation = configuration.getString("recommender.hidden.activation"); + outputActivation = configuration.getString("recommender.output.activation"); + learnRatio = configuration.getFloat("recommender.iterator.learnrate"); + momentum = configuration.getFloat("recommender.iterator.momentum"); + weightRegularization = configuration.getFloat("recommender.weight.regularization"); + } + + @Override + protected void doPractice() { + MultiLayerConfiguration configuration = getNetworkConfiguration(); + network = new MultiLayerNetwork(configuration); + network.init(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + network.fit(inputData, inputData); + totalError = (float) network.score(); + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + + outputData = 
network.output(inputData); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/ProbabilisticGraphicalModel.java b/src/main/java/com/jstarcraft/rns/model/ProbabilisticGraphicalModel.java new file mode 100644 index 0000000..1f31d5f --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/ProbabilisticGraphicalModel.java @@ -0,0 +1,149 @@ +package com.jstarcraft.rns.model; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.StringUtility; +import com.jstarcraft.rns.model.exception.ModelException; + +import it.unimi.dsi.fastutil.floats.Float2IntLinkedOpenHashMap; +import it.unimi.dsi.fastutil.floats.FloatRBTreeSet; +import it.unimi.dsi.fastutil.floats.FloatSet; + +/** + * 概率图推荐器 + * + * @author Birdy + * + */ +public abstract class ProbabilisticGraphicalModel extends EpocheModel { + + /** + * burn-in period + */ + protected int burnIn; + + /** + * size of statistics + */ + protected int numberOfStatistics = 0; + + /** + * number of topics + */ + protected int factorSize; + + /** 分数索引 (TODO 考虑取消或迁移.本质为连续特征离散化) */ + protected Float2IntLinkedOpenHashMap scoreIndexes; + + protected int scoreSize; + + /** + * sample lag (if -1 only one sample taken) + */ + protected int sampleSize; + + /** + * setup init member method + * + * @throws ModelException if error occurs during setting up + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + factorSize = configuration.getInteger("recommender.topic.number", 10); + burnIn = configuration.getInteger("recommender.pgm.burnin", 100); + sampleSize = configuration.getInteger("recommender.pgm.samplelag", 100); + + // TODO 此处会与scoreIndexes一起重构,本质为连续特征离散化. 
+ FloatSet scores = new FloatRBTreeSet(); + for (MatrixScalar term : scoreMatrix) { + scores.add(term.getValue()); + } + scores.remove(0F); + scoreIndexes = new Float2IntLinkedOpenHashMap(); + int index = 0; + for (float score : scores) { + scoreIndexes.put(score, index++); + } + scoreSize = scoreIndexes.size(); + } + + @Override + protected void doPractice() { + long now = System.currentTimeMillis(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + // E-step: infer parameters + eStep(); + if (logger.isInfoEnabled()) { + String message = StringUtility.format("eStep time is {}", System.currentTimeMillis() - now); + now = System.currentTimeMillis(); + logger.info(message); + } + + // M-step: update hyper-parameters + mStep(); + if (logger.isInfoEnabled()) { + String message = StringUtility.format("mStep time is {}", System.currentTimeMillis() - now); + now = System.currentTimeMillis(); + logger.info(message); + } + // get statistics after burn-in + if ((epocheIndex > burnIn) && (epocheIndex % sampleSize == 0)) { + readoutParameters(); + if (logger.isInfoEnabled()) { + String message = StringUtility.format("readoutParams time is {}", System.currentTimeMillis() - now); + now = System.currentTimeMillis(); + logger.info(message); + } + estimateParameters(); + if (logger.isInfoEnabled()) { + String message = StringUtility.format("estimateParams time is {}", System.currentTimeMillis() - now); + now = System.currentTimeMillis(); + logger.info(message); + } + } + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + // retrieve posterior probability distributions + estimateParameters(); + if (logger.isInfoEnabled()) { + String message = StringUtility.format("estimateParams time is {}", System.currentTimeMillis() - now); + now = System.currentTimeMillis(); + logger.info(message); + } + } + + protected boolean isConverged(int iter) { + return false; + } + + /** + * parameters estimation: used in the training phase + */ + protected abstract void eStep(); + + /** + * update the hyper-parameters + */ + protected abstract void mStep(); + + /** + * read out parameters for each iteration + */ + protected void readoutParameters() { + + } + + /** + * estimate the model parameters + */ + protected void estimateParameters() { + + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/SocialModel.java b/src/main/java/com/jstarcraft/rns/model/SocialModel.java new file mode 100644 index 0000000..e33ae78 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/SocialModel.java @@ -0,0 +1,87 @@ +package com.jstarcraft.rns.model; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.core.common.option.Option; + +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * 社交推荐器 + * + *
 + * Note: relationship bases are the fundamental factors that form interpersonal relations, including kinship, locality, occupation and shared interest. + * In real applications, take care to distinguish community relations between people (shared interest) from social relations (kinship, locality, occupation). 
+ * 
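 + * + * The social module is read as (truster, trustee, coefficient) triples; a small + * sketch of the matrix prepare() builds from it (the indexes are assumed for + * illustration): + * + *
                       + * // an instance (truster = 0, trustee = 2, coefficient = 0.8F) yields
                       + * // socialMatrix.getValue(0, 2) == 0.8F
                       + * 
                      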
 + * + * @author Birdy + * + */ +public abstract class SocialModel extends MatrixFactorizationModel { + + protected String trusterField, trusteeField, coefficientField; + + protected int trusterDimension, trusteeDimension, coefficientDimension; + /** + * socialMatrix: social rate matrix, indicating a user is connecting to a number + * of other users + */ + protected SparseMatrix socialMatrix; + + /** + * social regularization + */ + protected float socialRegularization; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + socialRegularization = configuration.getFloat("recommender.social.regularization", 0.01f); + // social path for the socialMatrix + // TODO Should context.getSimilarity().getSimilarityMatrix() be used here instead? + DataModule socialModel = space.getModule("social"); + // TODO Needs refactoring: trusterDimension and trusteeDimension should be configurable + coefficientField = configuration.getString("data.model.fields.coefficient"); + trusterDimension = socialModel.getQualityInner(userField) + 0; + trusteeDimension = socialModel.getQualityInner(userField) + 1; + coefficientDimension = socialModel.getQuantityInner(coefficientField); + HashMatrix matrix = new HashMatrix(true, userSize, userSize, new Long2FloatRBTreeMap()); + for (DataInstance instance : socialModel) { + matrix.setValue(instance.getQualityFeature(trusterDimension), instance.getQualityFeature(trusteeDimension), instance.getQuantityFeature(coefficientDimension)); + } + socialMatrix = SparseMatrix.valueOf(userSize, userSize, matrix); + } + + /** + * Denormalize + * + *
                      
                      +     * Converts a value from (0,1) to (minimumOfScore,maximumOfScore)
                      
+     * 
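                      +     *
                      +     * For example, with minimumScore = 1 and maximumScore = 5, denormalize(0.5F)
                      +     * returns 1 + 0.5 * (5 - 1) = 3.
                      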
                      +     *
                      +     * @param value
                      +     * @return
                      +     */
                      +    protected float denormalize(float value) {
                      +        return minimumScore + value * (maximumScore - minimumScore);
                      +    }
                      +
                      +    /**
                      +     * Normalize
                      +     *
                      +     * 
                      
                      +     * Converts a value from (minimumOfScore,maximumOfScore) to (0,1)
                      
+     * 
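                      +     *
                      +     * For example, with minimumScore = 1 and maximumScore = 5, normalize(3F)
                      +     * returns (3 - 1) / (5 - 1) = 0.5.
                      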
                      +     *
                      +     * @param value
                      +     * @return
                      +     */
                      +    protected float normalize(float value) {
                      +        return (value - minimumScore) / (maximumScore - minimumScore);
                      +    }
                      +
                      +}
                      diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/RandomGuessModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/RandomGuessModel.java
                      new file mode 100644
                      index 0000000..54baf54
                      --- /dev/null
                      +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/RandomGuessModel.java
                      @@ -0,0 +1,34 @@
                      +package com.jstarcraft.rns.model.benchmark;
                      +
                      +import com.jstarcraft.ai.data.DataInstance;
                      +import com.jstarcraft.ai.modem.ModemDefinition;
                      +import com.jstarcraft.core.utility.RandomUtility;
                      +import com.jstarcraft.rns.model.AbstractModel;
                      +
                      +/**
                      + *
                      + * Random Guess recommender
                      + *
                      + * 
                      
                      + * Based on the LibRec team's implementation
                      
+ * 
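                      + *
                      + * The generator is reseeded for every user-item pair, so predictions are random
                      + * but reproducible; the two calls below are the ones predict() makes:
                      + *
                      + * 
                       + * RandomUtility.setSeed(userIndex * itemSize + itemIndex);
                       + * float score = RandomUtility.randomFloat(minimumScore, maximumScore);
                       + * 
                      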
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "userDimension", "itemDimension", "numberOfItems", "minimumOfScore", "maximumOfScore" }) +public class RandomGuessModel extends AbstractModel { + + @Override + protected void doPractice() { + } + + @Override + public synchronized void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + RandomUtility.setSeed(userIndex * itemSize + itemIndex); + instance.setQuantityMark(RandomUtility.randomFloat(minimumScore, maximumScore)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/ranking/MostPopularModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/ranking/MostPopularModel.java new file mode 100644 index 0000000..a355917 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/ranking/MostPopularModel.java @@ -0,0 +1,45 @@ +package com.jstarcraft.rns.model.benchmark.ranking; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.modem.ModemDefinition; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.AbstractModel; + +/** + * + * Most Popular推荐器 + * + *
 + * Based on the LibRec team's implementation 
+ * 
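 + * + * The "score" of an item is simply its rating count, read once per item during + * training: + * + *
                       + * populars[itemIndex] = scoreMatrix.getColumnScope(itemIndex);
                       + * 
                      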
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "itemDimension", "populars" }) +public class MostPopularModel extends AbstractModel { + + private int[] populars; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + populars = new int[itemSize]; + } + + @Override + protected void doPractice() { + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + populars[itemIndex] = scoreMatrix.getColumnScope(itemIndex); + } + } + + @Override + public void predict(DataInstance instance) { + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(populars[itemIndex]); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ConstantGuessModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ConstantGuessModel.java new file mode 100644 index 0000000..1c22e08 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ConstantGuessModel.java @@ -0,0 +1,44 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.modem.ModemDefinition; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.AbstractModel; + +/** + * + * Constant Guess推荐器 + * + *
 + * Based on the LibRec team's implementation 
+ * 
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "constant" }) +public class ConstantGuessModel extends AbstractModel { + + private float constant; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // 默认使用最高最低分的平均值 + constant = (minimumScore + maximumScore) / 2F; + // TODO 支持配置分数 + constant = configuration.getFloat("recommend.constant-guess.score", constant); + } + + @Override + protected void doPractice() { + } + + @Override + public void predict(DataInstance instance) { + instance.setQuantityMark(constant); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/rating/GlobalAverageModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/GlobalAverageModel.java new file mode 100644 index 0000000..82dae32 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/GlobalAverageModel.java @@ -0,0 +1,30 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.modem.ModemDefinition; +import com.jstarcraft.rns.model.AbstractModel; + +/** + * + * Global Average推荐器 + * + *
 + * Based on the LibRec team's implementation 
+ * 
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "meanOfScore" }) +public class GlobalAverageModel extends AbstractModel { + + @Override + protected void doPractice() { + } + + @Override + public void predict(DataInstance instance) { + instance.setQuantityMark(meanScore); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ItemAverageModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ItemAverageModel.java new file mode 100644 index 0000000..d7895d6 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ItemAverageModel.java @@ -0,0 +1,48 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.modem.ModemDefinition; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.AbstractModel; + +/** + * + * Item Average推荐器 + * + *
 + * Based on the LibRec team's implementation 
+ * 
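 + * + * Items without any rating fall back to the global mean, as computed during + * training: + * + *
                       + * itemMeans[itemIndex] = itemVector.getElementSize() == 0
                       + *         ? meanScore
                       + *         : itemVector.getSum(false) / itemVector.getElementSize();
                       + * 
                      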
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "itemDimension", "itemMeans" }) +public class ItemAverageModel extends AbstractModel { + + /** 物品平均分数 */ + private float[] itemMeans; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + itemMeans = new float[itemSize]; + } + + @Override + protected void doPractice() { + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + itemMeans[itemIndex] = itemVector.getElementSize() == 0 ? meanScore : itemVector.getSum(false) / itemVector.getElementSize(); + } + } + + @Override + public void predict(DataInstance instance) { + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(itemMeans[itemIndex]); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ItemClusterModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ItemClusterModel.java new file mode 100644 index 0000000..6ea1102 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/ItemClusterModel.java @@ -0,0 +1,183 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import java.util.Map.Entry; + +import org.apache.commons.math3.util.FastMath; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.ai.modem.ModemDefinition; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.ProbabilisticGraphicalModel; + +/** + * + * Item Cluster推荐器 + * + *
 + * Based on the LibRec team's implementation 
+ * 
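 + * + * Prediction sketch: the expected score marginalizes over the latent item + * topics, where gamma is itemTopicProbabilities and P(r | k) comes from + * topicScoreMatrix: + * + *
                       + * value = sum over k of gamma(item, k) * (sum over r of r * P(r | k))
                       + * 
                      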
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "userDimension", "itemDimension", "itemTopicProbabilities", "numberOfFactors", "scoreIndexes", "topicScoreMatrix" }) +public class ItemClusterModel extends ProbabilisticGraphicalModel { + + /** 物品的每评分次数 */ + private DenseMatrix itemScoreMatrix; // Nur + /** 物品的总评分次数 */ + private DenseVector itemScoreVector; // Nu + + /** 主题的每评分概率 */ + private DenseMatrix topicScoreMatrix; // Pkr + /** 主题的总评分概率 */ + private DenseVector topicScoreVector; // Pi + + /** 物品主题概率映射 */ + private DenseMatrix itemTopicProbabilities; // Gamma_(u,k) + + @Override + protected boolean isConverged(int iter) { + // TODO 需要重构 + float loss = 0F; + for (int i = 0; i < itemSize; i++) { + for (int k = 0; k < factorSize; k++) { + float rik = itemTopicProbabilities.getValue(i, k); + float pi_k = topicScoreVector.getValue(k); + + float sum_nl = 0F; + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + float nir = itemScoreMatrix.getValue(i, scoreIndex); + float pkr = topicScoreMatrix.getValue(k, scoreIndex); + + sum_nl += nir * Math.log(pkr); + } + + loss += rik * (Math.log(pi_k) + sum_nl); + } + } + float deltaLoss = (float) (loss - currentError); + if (iter > 1 && (deltaLoss > 0 || Float.isNaN(deltaLoss))) { + return true; + } + currentError = loss; + return false; + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + topicScoreMatrix = DenseMatrix.valueOf(factorSize, scoreSize); + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + DenseVector probabilityVector = topicScoreMatrix.getRowVector(topicIndex); + probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + float value = scalar.getValue(); + scalar.setValue(RandomUtility.randomInteger(scoreSize) + 1); + }); + probabilityVector.scaleValues(1F / probabilityVector.getSum(false)); + } + topicScoreVector = DenseVector.valueOf(factorSize); + topicScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomInteger(factorSize) + 1); + }); + topicScoreVector.scaleValues(1F / topicScoreVector.getSum(false)); + // TODO + topicScoreMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((float) Math.log(scalar.getValue())); + }); + topicScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((float) Math.log(scalar.getValue())); + }); + + itemScoreMatrix = DenseMatrix.valueOf(itemSize, scoreSize); + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector scoreVector = scoreMatrix.getColumnVector(itemIndex); + for (VectorScalar term : scoreVector) { + float score = term.getValue(); + int scoreIndex = scoreIndexes.get(score); + itemScoreMatrix.shiftValue(itemIndex, scoreIndex, 1); + } + } + itemScoreVector = DenseVector.valueOf(itemSize); + itemScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(scoreMatrix.getColumnVector(scalar.getIndex()).getElementSize()); + }); + currentError = Float.MIN_VALUE; + + itemTopicProbabilities = DenseMatrix.valueOf(itemSize, factorSize); + } + + @Override + protected void eStep() { + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + DenseVector probabilityVector = itemTopicProbabilities.getRowVector(itemIndex); + SparseVector scoreVector = scoreMatrix.getColumnVector(itemIndex); + if (scoreVector.getElementSize() == 0) { + probabilityVector.copyVector(topicScoreVector); + } else { + 
probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float topicProbability = topicScoreVector.getValue(index); + for (VectorScalar term : scoreVector) { + int scoreIndex = scoreIndexes.get(term.getValue()); + float scoreProbability = topicScoreMatrix.getValue(index, scoreIndex); + topicProbability = topicProbability + scoreProbability; + } + scalar.setValue(topicProbability); + }); + probabilityVector.scaleValues(1F / probabilityVector.getSum(false)); + } + } + } + + @Override + protected void mStep() { + topicScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + float numerator = 0F, denorminator = 0F; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + float probability = (float) FastMath.exp(itemTopicProbabilities.getValue(itemIndex, index)); + numerator += probability * itemScoreMatrix.getValue(itemIndex, scoreIndex); + denorminator += probability * itemScoreVector.getValue(itemIndex); + } + float probability = (numerator / denorminator); + topicScoreMatrix.setValue(index, scoreIndex, probability); + } + float sumProbability = 0F; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + float probability = (float) FastMath.exp(itemTopicProbabilities.getValue(itemIndex, index)); + sumProbability += probability; + } + scalar.setValue(sumProbability); + }); + topicScoreVector.scaleValues(1F / topicScoreVector.getSum(false)); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = 0F; + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float topicProbability = itemTopicProbabilities.getValue(itemIndex, topicIndex); // probability + float topicValue = 0F; + for (Entry entry : scoreIndexes.entrySet()) { + float score = entry.getKey(); + float probability = topicScoreMatrix.getValue(topicIndex, entry.getValue()); + topicValue += score * probability; + } + value += topicProbability * topicValue; + } + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/rating/UserAverageModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/UserAverageModel.java new file mode 100644 index 0000000..a671be1 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/UserAverageModel.java @@ -0,0 +1,48 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.modem.ModemDefinition; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.AbstractModel; + +/** + * + * User Average推荐器 + * + *
 + * Based on the LibRec team's implementation 
+ * 
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "userDimension", "userMeans" }) +public class UserAverageModel extends AbstractModel { + + /** 用户平均分数 */ + private float[] userMeans; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userMeans = new float[userSize]; + } + + @Override + protected void doPractice() { + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + userMeans[userIndex] = userVector.getElementSize() == 0 ? meanScore : userVector.getSum(false) / userVector.getElementSize(); + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + instance.setQuantityMark(userMeans[userIndex]); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/benchmark/rating/UserClusterModel.java b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/UserClusterModel.java new file mode 100644 index 0000000..3440155 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/benchmark/rating/UserClusterModel.java @@ -0,0 +1,186 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import java.util.Map.Entry; + +import org.apache.commons.math3.util.FastMath; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.ai.modem.ModemDefinition; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.ProbabilisticGraphicalModel; + +/** + * + * User Cluster推荐器 + * + *
 + * Based on the LibRec team's implementation 
+ * 
+ * + * @author Birdy + * + */ +@ModemDefinition(value = { "userDimension", "itemDimension", "userTopicProbabilities", "numberOfFactors", "scoreIndexes", "topicScoreMatrix" }) +public class UserClusterModel extends ProbabilisticGraphicalModel { + + /** 用户的每评分次数 */ + private DenseMatrix userScoreMatrix; // Nur + /** 用户的总评分次数 */ + private DenseVector userScoreVector; // Nu + + /** 主题的每评分概率 */ + private DenseMatrix topicScoreMatrix; // Pkr + /** 主题的总评分概率 */ + private DenseVector topicScoreVector; // Pi + + /** 用户主题概率映射 */ + private DenseMatrix userTopicProbabilities; // Gamma_(u,k) + + @Override + protected boolean isConverged(int iter) { + // TODO 需要重构 + float loss = 0F; + + for (int u = 0; u < userSize; u++) { + for (int k = 0; k < factorSize; k++) { + float ruk = userTopicProbabilities.getValue(u, k); + float pi_k = topicScoreVector.getValue(k); + + float sum_nl = 0F; + for (int r = 0; r < scoreIndexes.size(); r++) { + float nur = userScoreMatrix.getValue(u, r); + float pkr = topicScoreMatrix.getValue(k, r); + + sum_nl += nur * Math.log(pkr); + } + + loss += ruk * (Math.log(pi_k) + sum_nl); + } + } + + float deltaLoss = (float) (loss - currentError); + + if (iter > 1 && (deltaLoss > 0 || Float.isNaN(deltaLoss))) { + return true; + } + + currentError = loss; + return false; + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + topicScoreMatrix = DenseMatrix.valueOf(factorSize, scoreSize); + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + DenseVector probabilityVector = topicScoreMatrix.getRowVector(topicIndex); + probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomInteger(scoreSize) + 1); + }); + probabilityVector.scaleValues(1F / probabilityVector.getSum(false)); + } + topicScoreVector = DenseVector.valueOf(factorSize); + topicScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomInteger(factorSize) + 1); + }); + topicScoreVector.scaleValues(1F / topicScoreVector.getSum(false)); + // TODO + topicScoreMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((float) Math.log(scalar.getValue())); + }); + topicScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((float) Math.log(scalar.getValue())); + }); + + userScoreMatrix = DenseMatrix.valueOf(userSize, scoreSize); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector scoreVector = scoreMatrix.getRowVector(userIndex); + for (VectorScalar term : scoreVector) { + float score = term.getValue(); + int scoreIndex = scoreIndexes.get(score); + userScoreMatrix.shiftValue(userIndex, scoreIndex, 1); + } + } + userScoreVector = DenseVector.valueOf(userSize); + userScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(scoreMatrix.getRowVector(scalar.getIndex()).getElementSize()); + }); + currentError = Float.MIN_VALUE; + + userTopicProbabilities = DenseMatrix.valueOf(userSize, factorSize); + } + + @Override + protected void eStep() { + for (int userIndex = 0; userIndex < userSize; userIndex++) { + DenseVector probabilityVector = userTopicProbabilities.getRowVector(userIndex); + SparseVector scoreVector = scoreMatrix.getRowVector(userIndex); + if (scoreVector.getElementSize() == 0) { + probabilityVector.copyVector(topicScoreVector); + } else { + probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = 
scalar.getIndex(); + float topicProbability = topicScoreVector.getValue(index); + for (VectorScalar term : scoreVector) { + int scoreIndex = scoreIndexes.get(term.getValue()); + float scoreProbability = topicScoreMatrix.getValue(index, scoreIndex); + topicProbability = topicProbability + scoreProbability; + } + scalar.setValue(topicProbability); + }); + probabilityVector.scaleValues(1F / probabilityVector.getSum(false)); + } + } + } + + @Override + protected void mStep() { + topicScoreVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + float numerator = 0F, denorminator = 0F; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + float probability = (float) FastMath.exp(userTopicProbabilities.getValue(userIndex, index)); + numerator += probability * userScoreMatrix.getValue(userIndex, scoreIndex); + denorminator += probability * userScoreVector.getValue(userIndex); + } + float probability = (numerator / denorminator); + topicScoreMatrix.setValue(index, scoreIndex, probability); + } + float sumProbability = 0F; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + float probability = (float) FastMath.exp(userTopicProbabilities.getValue(userIndex, index)); + sumProbability += probability; + } + scalar.setValue(sumProbability); + }); + topicScoreVector.scaleValues(1F / topicScoreVector.getSum(false)); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = 0F; + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float topicProbability = userTopicProbabilities.getValue(userIndex, topicIndex); + float topicValue = 0F; + for (Entry entry : scoreIndexes.entrySet()) { + float score = entry.getKey(); + float probability = topicScoreMatrix.getValue(topicIndex, entry.getValue()); + topicValue += score * probability; + } + value += topicProbability * topicValue; + } + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/BHFreeModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/BHFreeModel.java new file mode 100644 index 0000000..3bdb88d --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/BHFreeModel.java @@ -0,0 +1,277 @@ +package com.jstarcraft.rns.model.collaborative; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.MathCell; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.table.SparseTable; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.ProbabilisticGraphicalModel; +import com.jstarcraft.rns.utility.SampleUtility; + +import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap; + +/** + * + * BH Free推荐器 + * + *
+ * Balancing Prediction and Recommendation Accuracy: Hierarchical Latent Factors for Preference Data
 + * Based on the LibRec team's implementation 
+ * 
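 + * + * Gibbs sampling sketch: each observation is reassigned a pair (user topic k, + * item topic l) with probability proportional to + * + *
                       + * (N_uk + alpha) / (N_u + K * alpha)
                       + *         * (N_kl + beta) / (N_k + L * beta)
                       + *         * (N_klr + gamma) / (N_kl + R * gamma)
                       + *         * (N_kli + sigma) / (N_kl + I * sigma)
                       + * 
                      + * + * where K, L, R and I are the numbers of user topics, item topics, score levels + * and items respectively. 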
+ * + * @author Birdy + * + */ +public abstract class BHFreeModel extends ProbabilisticGraphicalModel { + + private static class TopicTerm { + + private int userTopic; + + private int itemTopic; + + private int scoreIndex; + + private TopicTerm(int userTopic, int itemTopic, int scoreIndex) { + this.userTopic = userTopic; + this.itemTopic = itemTopic; + this.scoreIndex = scoreIndex; + } + + void update(int userTopic, int itemTopic) { + this.userTopic = userTopic; + this.itemTopic = itemTopic; + } + + public int getUserTopic() { + return userTopic; + } + + public int getItemTopic() { + return itemTopic; + } + + public int getScoreIndex() { + return scoreIndex; + } + + } + + private SparseTable topicMatrix; + + private float initGamma, initSigma, initAlpha, initBeta; + + /** + * number of user communities + */ + protected int userTopicSize; // K + + /** + * number of item categories + */ + protected int itemTopicSize; // L + + /** + * evaluation of the user u which have been assigned to the user topic k + */ + private DenseMatrix user2TopicNumbers; + + /** + * observations for the user + */ + private DenseVector userNumbers; + + /** + * observations associated with community k + */ + private DenseVector userTopicNumbers; + + /** + * number of user communities * number of topics + */ + private DenseMatrix userTopic2ItemTopicNumbers; // Nkl + + /** + * number of user communities * number of topics * number of ratings + */ + private int[][][] userTopic2ItemTopicScoreNumbers, userTopic2ItemTopicItemNumbers; // Nklr, + // Nkli; + + // parameters + protected DenseMatrix user2TopicProbabilities, userTopic2ItemTopicProbabilities; + protected DenseMatrix user2TopicSums, userTopic2ItemTopicSums; + protected double[][][] userTopic2ItemTopicScoreProbabilities, userTopic2ItemTopicItemProbabilities; + protected double[][][] userTopic2ItemTopicScoreSums, userTopic2ItemTopicItemSums; + + private DenseMatrix topicProbabilities; + private DenseVector userProbabilities; + private DenseVector itemProbabilities; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userTopicSize = configuration.getInteger("recommender.bhfree.user.topic.number", 10); + itemTopicSize = configuration.getInteger("recommender.bhfree.item.topic.number", 10); + initAlpha = configuration.getFloat("recommender.bhfree.alpha", 1.0f / userTopicSize); + initBeta = configuration.getFloat("recommender.bhfree.beta", 1.0f / itemTopicSize); + initGamma = configuration.getFloat("recommender.bhfree.gamma", 1.0f / scoreSize); + initSigma = configuration.getFloat("recommender.sigma", 1.0f / itemSize); + scoreSize = scoreIndexes.size(); + + // TODO 考虑重构(整合为UserTopic对象) + user2TopicNumbers = DenseMatrix.valueOf(userSize, userTopicSize); + userNumbers = DenseVector.valueOf(userSize); + + userTopic2ItemTopicNumbers = DenseMatrix.valueOf(userTopicSize, itemTopicSize); + userTopicNumbers = DenseVector.valueOf(userTopicSize); + + userTopic2ItemTopicScoreNumbers = new int[userTopicSize][itemTopicSize][scoreSize]; + userTopic2ItemTopicItemNumbers = new int[userTopicSize][itemTopicSize][itemSize]; + + topicMatrix = new SparseTable<>(true, userSize, itemSize, new Int2ObjectRBTreeMap<>()); + + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + int scoreIndex = scoreIndexes.get(score); + int userTopic = RandomUtility.randomInteger(userTopicSize); // user's + // topic + // k + int 
itemTopic = RandomUtility.randomInteger(itemTopicSize); // item's + // topic + // l + + user2TopicNumbers.shiftValue(userIndex, userTopic, 1F); + userNumbers.shiftValue(userIndex, 1F); + userTopic2ItemTopicNumbers.shiftValue(userTopic, itemTopic, 1F); + userTopicNumbers.shiftValue(userTopic, 1F); + userTopic2ItemTopicScoreNumbers[userTopic][itemTopic][scoreIndex]++; + userTopic2ItemTopicItemNumbers[userTopic][itemTopic][itemIndex]++; + TopicTerm topic = new TopicTerm(userTopic, itemTopic, scoreIndex); + topicMatrix.setValue(userIndex, itemIndex, topic); + } + + // parameters + // TODO 考虑重构为一个对象 + user2TopicSums = DenseMatrix.valueOf(userSize, userTopicSize); + userTopic2ItemTopicSums = DenseMatrix.valueOf(userTopicSize, itemTopicSize); + userTopic2ItemTopicScoreSums = new double[userTopicSize][itemTopicSize][scoreSize]; + userTopic2ItemTopicScoreProbabilities = new double[userTopicSize][itemTopicSize][scoreSize]; + userTopic2ItemTopicItemSums = new double[userTopicSize][itemTopicSize][itemSize]; + userTopic2ItemTopicItemProbabilities = new double[userTopicSize][itemTopicSize][itemSize]; + + topicProbabilities = DenseMatrix.valueOf(userTopicSize, itemTopicSize); + userProbabilities = DenseVector.valueOf(userTopicSize); + itemProbabilities = DenseVector.valueOf(itemTopicSize); + } + + @Override + protected void eStep() { + for (MathCell term : topicMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + TopicTerm topicTerm = term.getValue(); + int scoreIndex = topicTerm.getScoreIndex(); + int userTopic = topicTerm.getUserTopic(); + int itemTopic = topicTerm.getItemTopic(); + + user2TopicNumbers.shiftValue(userIndex, userTopic, -1F); + userNumbers.shiftValue(userIndex, -1F); + userTopic2ItemTopicNumbers.shiftValue(userTopic, itemTopic, -1F); + userTopicNumbers.shiftValue(userTopic, -1F); + userTopic2ItemTopicScoreNumbers[userTopic][itemTopic][scoreIndex]--; + userTopic2ItemTopicItemNumbers[userTopic][itemTopic][itemIndex]--; + + // normalization + int userTopicIndex = userTopic; + int itemTopicIndex = itemTopic; + topicProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + float value = (user2TopicNumbers.getValue(userIndex, userTopicIndex) + initAlpha) / (userNumbers.getValue(userIndex) + userTopicSize * initAlpha); + value *= (userTopic2ItemTopicNumbers.getValue(userTopicIndex, itemTopicIndex) + initBeta) / (userTopicNumbers.getValue(userTopicIndex) + itemTopicSize * initBeta); + value *= (userTopic2ItemTopicScoreNumbers[userTopicIndex][itemTopicIndex][scoreIndex] + initGamma) / (userTopic2ItemTopicNumbers.getValue(userTopicIndex, itemTopicIndex) + scoreSize * initGamma); + value *= (userTopic2ItemTopicItemNumbers[userTopicIndex][itemTopicIndex][itemIndex] + initSigma) / (userTopic2ItemTopicNumbers.getValue(userTopicIndex, itemTopicIndex) + itemSize * initSigma); + scalar.setValue(value); + }); + + // 计算概率 + DefaultScalar sum = DefaultScalar.getInstance(); + sum.setValue(0F); + userProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = topicProbabilities.getRowVector(index).getSum(false); + sum.shiftValue(value); + scalar.setValue(sum.getValue()); + }); + userTopic = SampleUtility.binarySearch(userProbabilities, 0, userProbabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue())); + sum.setValue(0F); + itemProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = topicProbabilities.getColumnVector(index).getSum(false); + 
sum.shiftValue(value); + scalar.setValue(sum.getValue()); + }); + itemTopic = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue())); + + topicTerm.update(userTopic, itemTopic); + // add statistic + user2TopicNumbers.shiftValue(userIndex, userTopic, 1F); + userNumbers.shiftValue(userIndex, 1F); + userTopic2ItemTopicNumbers.shiftValue(userTopic, itemTopic, 1F); + userTopicNumbers.shiftValue(userTopic, 1F); + userTopic2ItemTopicScoreNumbers[userTopic][itemTopic][scoreIndex]++; + userTopic2ItemTopicItemNumbers[userTopic][itemTopic][itemIndex]++; + } + + } + + @Override + protected void mStep() { + + } + + @Override + protected void readoutParameters() { + for (int userTopic = 0; userTopic < userTopicSize; userTopic++) { + for (int userIndex = 0; userIndex < userSize; userIndex++) { + user2TopicSums.shiftValue(userIndex, userTopic, (user2TopicNumbers.getValue(userIndex, userTopic) + initAlpha) / (userNumbers.getValue(userIndex) + userTopicSize * initAlpha)); + } + for (int itemTopic = 0; itemTopic < itemTopicSize; itemTopic++) { + userTopic2ItemTopicSums.shiftValue(userTopic, itemTopic, (userTopic2ItemTopicNumbers.getValue(userTopic, itemTopic) + initBeta) / (userTopicNumbers.getValue(userTopic) + itemTopicSize * initBeta)); + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + userTopic2ItemTopicScoreSums[userTopic][itemTopic][scoreIndex] += (userTopic2ItemTopicScoreNumbers[userTopic][itemTopic][scoreIndex] + initGamma) / (userTopic2ItemTopicNumbers.getValue(userTopic, itemTopic) + scoreSize * initGamma); + } + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + userTopic2ItemTopicItemSums[userTopic][itemTopic][itemIndex] += (userTopic2ItemTopicItemNumbers[userTopic][itemTopic][itemIndex] + initSigma) / (userTopic2ItemTopicNumbers.getValue(userTopic, itemTopic) + itemSize * initSigma); + } + } + } + numberOfStatistics++; + } + + @Override + protected void estimateParameters() { + float scale = 1F / numberOfStatistics; + user2TopicProbabilities = DenseMatrix.copyOf(user2TopicSums); + user2TopicProbabilities.scaleValues(scale); + userTopic2ItemTopicProbabilities = DenseMatrix.copyOf(userTopic2ItemTopicSums); + userTopic2ItemTopicProbabilities.scaleValues(scale); + for (int userTopic = 0; userTopic < userTopicSize; userTopic++) { + for (int itemTopic = 0; itemTopic < itemTopicSize; itemTopic++) { + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + userTopic2ItemTopicScoreProbabilities[userTopic][itemTopic][scoreIndex] = userTopic2ItemTopicScoreSums[userTopic][itemTopic][scoreIndex] * scale; + } + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + userTopic2ItemTopicItemProbabilities[userTopic][itemTopic][itemIndex] = userTopic2ItemTopicItemSums[userTopic][itemTopic][itemIndex] * scale; + } + } + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/BUCMModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/BUCMModel.java new file mode 100644 index 0000000..322aaf7 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/BUCMModel.java @@ -0,0 +1,370 @@ +package com.jstarcraft.rns.model.collaborative; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; 
+import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.ProbabilisticGraphicalModel; +import com.jstarcraft.rns.utility.GammaUtility; +import com.jstarcraft.rns.utility.SampleUtility; + +import it.unimi.dsi.fastutil.ints.Int2IntRBTreeMap; + +/** + * + * BUCM推荐器 + * + *
+ * Bayesian User Community Model
+ * Modeling Item Selection and Relevance for Accurate Recommendations
 + * Based on the LibRec team's implementation 
+ * 
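 + * + * Collapsed Gibbs sampling sketch: each (user, item, score) observation is + * reassigned a topic k with probability proportional to + * + *
                       + * (N_uk + alpha_k) / (N_u + sum(alpha))
                       + *         * (N_ki + beta_i) / (N_k + sum(beta))
                       + *         * (N_kir + gamma_r) / (N_ki + sum(gamma))
                       + * 
                      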
+ * + * @author Birdy + * + */ +public abstract class BUCMModel extends ProbabilisticGraphicalModel { + /** + * number of occurrences of entry (t, i, r) + */ + private int[][][] topicItemScoreNumbers; + + /** + * number of occurrentces of entry (user, topic) + */ + private DenseMatrix userTopicNumbers; + + /** + * number of occurences of users + */ + private DenseVector userNumbers; + + /** + * number of occurrences of entry (topic, item) + */ + private DenseMatrix topicItemNumbers; + + /** + * number of occurrences of items + */ + private DenseVector topicNumbers; + + /** + * cumulative statistics of probabilities of (t, i, r) + */ + private float[][][] topicItemScoreSums; + + /** + * posterior probabilities of parameters epsilon_{k, i, r} + */ + protected float[][][] topicItemScoreProbabilities; + + /** + * P(k | u) + */ + protected DenseMatrix userTopicProbabilities, userTopicSums; + + /** + * P(i | k) + */ + protected DenseMatrix topicItemProbabilities, topicItemSums; + + /** + * + */ + private DenseVector alpha; + + /** + * + */ + private DenseVector beta; + + /** + * + */ + private DenseVector gamma; + + /** + * + */ + protected Int2IntRBTreeMap topicAssignments; + + private DenseVector probabilities; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // cumulative parameters + // TODO 考虑重构 + userTopicSums = DenseMatrix.valueOf(userSize, factorSize); + topicItemSums = DenseMatrix.valueOf(factorSize, itemSize); + topicItemScoreSums = new float[factorSize][itemSize][scoreSize]; + + // initialize count varialbes + userTopicNumbers = DenseMatrix.valueOf(userSize, factorSize); + userNumbers = DenseVector.valueOf(userSize); + + topicItemNumbers = DenseMatrix.valueOf(factorSize, itemSize); + topicNumbers = DenseVector.valueOf(factorSize); + + topicItemScoreNumbers = new int[factorSize][itemSize][scoreSize]; + + float initAlpha = configuration.getFloat("recommender.bucm.alpha", 1F / factorSize); + alpha = DenseVector.valueOf(factorSize); + alpha.setValues(initAlpha); + + float initBeta = configuration.getFloat("re.bucm.beta", 1F / itemSize); + beta = DenseVector.valueOf(itemSize); + beta.setValues(initBeta); + + float initGamma = configuration.getFloat("recommender.bucm.gamma", 1F / factorSize); + gamma = DenseVector.valueOf(scoreSize); + gamma.setValues(initGamma); + + // initialize topics + topicAssignments = new Int2IntRBTreeMap(); + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + int scoreIndex = scoreIndexes.get(score); // rating level 0 ~ + // numLevels + int topicIndex = RandomUtility.randomInteger(factorSize); // 0 ~ + // k-1 + + // Assign a topic t to pair (u, i) + topicAssignments.put(userIndex * itemSize + itemIndex, topicIndex); + // for users + userTopicNumbers.shiftValue(userIndex, topicIndex, 1F); + userNumbers.shiftValue(userIndex, 1F); + + // for items + topicItemNumbers.shiftValue(topicIndex, itemIndex, 1F); + topicNumbers.shiftValue(topicIndex, 1F); + + // for ratings + topicItemScoreNumbers[topicIndex][itemIndex][scoreIndex]++; + } + + probabilities = DenseVector.valueOf(factorSize); + } + + @Override + protected void eStep() { + float alphaSum = alpha.getSum(false); + float betaSum = beta.getSum(false); + float gammaSum = gamma.getSum(false); + + // collapse Gibbs sampling + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + 
+    @Override
+    protected void eStep() {
+        float alphaSum = alpha.getSum(false);
+        float betaSum = beta.getSum(false);
+        float gammaSum = gamma.getSum(false);
+
+        // collapsed Gibbs sampling
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            int scoreIndex = scoreIndexes.get(score); // rating level 0 ~ numLevels - 1
+            int topicIndex = topicAssignments.get(userIndex * itemSize + itemIndex);
+
+            // for user
+            userTopicNumbers.shiftValue(userIndex, topicIndex, -1F);
+            userNumbers.shiftValue(userIndex, -1F);
+
+            // for item
+            topicItemNumbers.shiftValue(topicIndex, itemIndex, -1F);
+            topicNumbers.shiftValue(topicIndex, -1F);
+
+            // for rating
+            topicItemScoreNumbers[topicIndex][itemIndex][scoreIndex]--;
+
+            // do multinomial sampling via the cumulative method:
+            // compute the probabilities
+            DefaultScalar sum = DefaultScalar.getInstance();
+            sum.setValue(0F);
+            probabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                int index = scalar.getIndex();
+                float value = (userTopicNumbers.getValue(userIndex, index) + alpha.getValue(index)) / (userNumbers.getValue(userIndex) + alphaSum);
+                value *= (topicItemNumbers.getValue(index, itemIndex) + beta.getValue(itemIndex)) / (topicNumbers.getValue(index) + betaSum);
+                value *= (topicItemScoreNumbers[index][itemIndex][scoreIndex] + gamma.getValue(scoreIndex)) / (topicItemNumbers.getValue(index, itemIndex) + gammaSum);
+                sum.shiftValue(value);
+                scalar.setValue(sum.getValue());
+            });
+            topicIndex = SampleUtility.binarySearch(probabilities, 0, probabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue()));
+
+            // new topic t
+            topicAssignments.put(userIndex * itemSize + itemIndex, topicIndex);
+
+            // add the newly sampled z_i back to the count variables
+            userTopicNumbers.shiftValue(userIndex, topicIndex, 1F);
+            userNumbers.shiftValue(userIndex, 1F);
+
+            topicItemNumbers.shiftValue(topicIndex, itemIndex, 1F);
+            topicNumbers.shiftValue(topicIndex, 1F);
+
+            topicItemScoreNumbers[topicIndex][itemIndex][scoreIndex]++;
+        }
+    }
+    /**
+     * Thomas P. Minka, Estimating a Dirichlet distribution, see Eq. (55)
+     */
+    @Override
+    protected void mStep() {
+        float denominator;
+        float value = 0F;
+
+        // update alpha
+        float alphaValue;
+        float alphaSum = alpha.getSum(false);
+        float alphaDigamma = GammaUtility.digamma(alphaSum);
+        denominator = 0F;
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            value = userNumbers.getValue(userIndex);
+            if (value != 0F) {
+                denominator += GammaUtility.digamma(value + alphaSum) - alphaDigamma;
+            }
+        }
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            alphaValue = alpha.getValue(topicIndex);
+            alphaDigamma = GammaUtility.digamma(alphaValue);
+            float numerator = 0F;
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                value = userTopicNumbers.getValue(userIndex, topicIndex);
+                if (value != 0F) {
+                    numerator += GammaUtility.digamma(value + alphaValue) - alphaDigamma;
+                }
+            }
+            if (numerator != 0F) {
+                alpha.setValue(topicIndex, alphaValue * (numerator / denominator));
+            }
+        }
+
+        // update beta
+        float betaValue;
+        float betaSum = beta.getSum(false);
+        float betaDigamma = GammaUtility.digamma(betaSum);
+        denominator = 0F;
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            value = topicNumbers.getValue(topicIndex);
+            if (value != 0F) {
+                denominator += GammaUtility.digamma(value + betaSum) - betaDigamma;
+            }
+        }
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            betaValue = beta.getValue(itemIndex);
+            betaDigamma = GammaUtility.digamma(betaValue);
+            float numerator = 0F;
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                value = topicItemNumbers.getValue(topicIndex, itemIndex);
+                if (value != 0F) {
+                    numerator += GammaUtility.digamma(value + betaValue) - betaDigamma;
+                }
+            }
+            if (numerator != 0F) {
+                beta.setValue(itemIndex, betaValue * (numerator / denominator));
+            }
+        }
+
+        // update gamma
+        float gammaValue;
+        float gammaSum = gamma.getSum(false);
+        float gammaDigamma = GammaUtility.digamma(gammaSum);
+        denominator = 0F;
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                value = topicItemNumbers.getValue(topicIndex, itemIndex);
+                if (value != 0F) {
+                    denominator += GammaUtility.digamma(value + gammaSum) - gammaDigamma;
+                }
+            }
+        }
+        for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) {
+            gammaValue = gamma.getValue(scoreIndex);
+            gammaDigamma = GammaUtility.digamma(gammaValue);
+            float numerator = 0F;
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                    value = topicItemScoreNumbers[topicIndex][itemIndex][scoreIndex];
+                    if (value != 0F) {
+                        numerator += GammaUtility.digamma(value + gammaValue) - gammaDigamma;
+                    }
+                }
+            }
+            if (numerator != 0F) {
+                gamma.setValue(scoreIndex, gammaValue * (numerator / denominator));
+            }
+        }
+    }
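+    /*
+     * A sketch of the fixed-point update applied above to every hyperparameter
+     * (shown for alpha, with psi the digamma function):
+     *
+     * alpha_k <- alpha_k * (Σ_u [psi(n(u,k) + alpha_k) - psi(alpha_k)])
+     *                    / (Σ_u [psi(n(u) + Σ alpha)   - psi(Σ alpha)])
+     */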
+    @Override
+    protected boolean isConverged(int iter) {
+        float loss = 0F;
+        // get the parameters
+        estimateParameters();
+        // compute the likelihood
+        int sum = 0;
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            int scoreIndex = scoreIndexes.get(score);
+            float probability = 0F;
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                probability += userTopicProbabilities.getValue(userIndex, topicIndex) * topicItemProbabilities.getValue(topicIndex, itemIndex) * topicItemScoreProbabilities[topicIndex][itemIndex][scoreIndex];
+            }
+            loss += (float) -Math.log(probability);
+            sum++;
+        }
+        loss /= sum;
+        float delta = loss - currentError; // while the loss is shrinking, delta <= 0
+        if (numberOfStatistics > 1 && delta > 0) {
+            return true;
+        }
+        currentError = loss;
+        return false;
+    }
+
+    protected void readoutParameters() {
+        float value;
+        float sumAlpha = alpha.getSum(false);
+        float sumBeta = beta.getSum(false);
+        float sumGamma = gamma.getSum(false);
+
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                value = (userTopicNumbers.getValue(userIndex, topicIndex) + alpha.getValue(topicIndex)) / (userNumbers.getValue(userIndex) + sumAlpha);
+                userTopicSums.shiftValue(userIndex, topicIndex, value);
+            }
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                value = (topicItemNumbers.getValue(topicIndex, itemIndex) + beta.getValue(itemIndex)) / (topicNumbers.getValue(topicIndex) + sumBeta);
+                topicItemSums.shiftValue(topicIndex, itemIndex, value);
+                for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) {
+                    value = (topicItemScoreNumbers[topicIndex][itemIndex][scoreIndex] + gamma.getValue(scoreIndex)) / (topicItemNumbers.getValue(topicIndex, itemIndex) + sumGamma);
+                    topicItemScoreSums[topicIndex][itemIndex][scoreIndex] += value;
+                }
+            }
+        }
+        numberOfStatistics++;
+    }
+
+    @Override
+    protected void estimateParameters() {
+        userTopicProbabilities = DenseMatrix.copyOf(userTopicSums);
+        userTopicProbabilities.scaleValues(1F / numberOfStatistics);
+        topicItemProbabilities = DenseMatrix.copyOf(topicItemSums);
+        topicItemProbabilities.scaleValues(1F / numberOfStatistics);
+
+        topicItemScoreProbabilities = new float[factorSize][itemSize][scoreSize];
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) {
+                    topicItemScoreProbabilities[topicIndex][itemIndex][scoreIndex] = topicItemScoreSums[topicIndex][itemIndex][scoreIndex] / numberOfStatistics;
+                }
+            }
+        }
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ItemKNNModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ItemKNNModel.java
new file mode 100644
index 0000000..196973a
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ItemKNNModel.java
@@ -0,0 +1,134 @@
+package com.jstarcraft.rns.model.collaborative;
+
+import java.util.Collection;
+import java.util.Comparator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.algorithm.correlation.MathCorrelation;
+import com.jstarcraft.ai.math.structure.vector.ArrayVector;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.common.reflection.ReflectionUtility;
+import com.jstarcraft.core.utility.Integer2FloatKeyValue;
+import com.jstarcraft.core.utility.Neighborhood;
+import com.jstarcraft.rns.model.AbstractModel;
+
+import it.unimi.dsi.fastutil.ints.Int2FloatMap;
+import it.unimi.dsi.fastutil.ints.Int2FloatRBTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2FloatSortedMap;
+
+/**
+ *
+ * Item KNN recommender
+ *
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public abstract class ItemKNNModel extends AbstractModel {
+
+    /** number of neighbors */
+    private int neighborSize;
+
+    protected DenseVector itemMeans;
+
+    /**
+     * item's nearest neighbors for kNN > 0
+     */
+    protected MathVector[] itemNeighbors;
+
+    protected SparseVector[] userVectors;
+
+    protected SparseVector[] itemVectors;
+
+    private Comparator<Integer2FloatKeyValue> comparator = new Comparator<Integer2FloatKeyValue>() {
+
+        @Override
+        public int compare(Integer2FloatKeyValue left, Integer2FloatKeyValue right) {
+            int compare = -(Float.compare(left.getValue(), right.getValue()));
+            if (compare == 0) {
+                compare = Integer.compare(left.getKey(), right.getKey());
+            }
+            return compare;
+        }
+
+    };
+
+    protected MathVector getNeighborVector(Collection<Integer2FloatKeyValue> neighbors) {
+        int size = neighbors.size();
+        int[] indexes = new int[size];
+        float[] values = new float[size];
+        Int2FloatSortedMap keyValues = new Int2FloatRBTreeMap();
+        for (Integer2FloatKeyValue term : neighbors) {
+            keyValues.put(term.getKey(), term.getValue());
+        }
+        int cursor = 0;
+        for (Int2FloatMap.Entry term : keyValues.int2FloatEntrySet()) {
+            indexes[cursor] = term.getIntKey();
+            values[cursor] = term.getFloatValue();
+            cursor++;
+        }
+        return new ArrayVector(size, indexes, values);
+    }
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        neighborSize = configuration.getInteger("recommender.neighbors.knn.number", 50);
+        // TODO set the capacity
+        itemNeighbors = new MathVector[itemSize];
+        Neighborhood<Integer2FloatKeyValue>[] knns = new Neighborhood[itemSize];
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            knns[itemIndex] = new Neighborhood<>(neighborSize, comparator);
+        }
+        // TODO change to a configurable enumeration
+        try {
+            Class<MathCorrelation> correlationClass = (Class<MathCorrelation>) Class.forName(configuration.getString("recommender.correlation.class"));
+            MathCorrelation correlation = ReflectionUtility.getInstance(correlationClass);
+            correlation.calculateCoefficients(scoreMatrix, true, (leftIndex, rightIndex, coefficient) -> {
+                if (leftIndex == rightIndex) {
+                    return;
+                }
+                // skip items whose similarity is 0
+                if (coefficient == 0F) {
+                    return;
+                }
+                knns[leftIndex].updateNeighbor(new Integer2FloatKeyValue(rightIndex, coefficient));
+                knns[rightIndex].updateNeighbor(new Integer2FloatKeyValue(leftIndex, coefficient));
+            });
+        } catch (Exception exception) {
+            throw new RuntimeException(exception);
+        }
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            itemNeighbors[itemIndex] = getNeighborVector(knns[itemIndex].getNeighbors());
+        }
+
+        itemMeans = DenseVector.valueOf(itemSize);
+
+        userVectors = new SparseVector[userSize];
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            userVectors[userIndex] = scoreMatrix.getRowVector(userIndex);
+        }
+
+        itemVectors = new SparseVector[itemSize];
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            itemVectors[itemIndex] = scoreMatrix.getColumnVector(itemIndex);
+        }
+    }
+
+    @Override
+    protected void doPractice() {
+        meanScore = scoreMatrix.getSum(false) / scoreMatrix.getElementSize();
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex);
+            itemMeans.setValue(itemIndex, itemVector.getElementSize() > 0 ? itemVector.getSum(false) / itemVector.getElementSize() : meanScore);
+        }
+    }
+
+}
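+// A usage sketch (hypothetical values; only the two option keys read above are
+// guaranteed by this file):
+//
+//   recommender.neighbors.knn.number=50
+//   recommender.correlation.class=com.example.SomeCorrelation
+//
+// Here com.example.SomeCorrelation stands for any MathCorrelation
+// implementation on the classpath; it is resolved reflectively via
+// Class.forName, so similarity metrics can be swapped without code changes.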
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/UserKNNModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/UserKNNModel.java
new file mode 100644
index 0000000..919a354
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/UserKNNModel.java
@@ -0,0 +1,134 @@
+package com.jstarcraft.rns.model.collaborative;
+
+import java.util.Collection;
+import java.util.Comparator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.algorithm.correlation.MathCorrelation;
+import com.jstarcraft.ai.math.structure.vector.ArrayVector;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.common.reflection.ReflectionUtility;
+import com.jstarcraft.core.utility.Integer2FloatKeyValue;
+import com.jstarcraft.core.utility.Neighborhood;
+import com.jstarcraft.rns.model.AbstractModel;
+
+import it.unimi.dsi.fastutil.ints.Int2FloatMap;
+import it.unimi.dsi.fastutil.ints.Int2FloatRBTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2FloatSortedMap;
+
+/**
+ *
+ * User KNN recommender
+ *
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public abstract class UserKNNModel extends AbstractModel {
+
+    /** number of neighbors */
+    private int neighborSize;
+
+    protected DenseVector userMeans;
+
+    /**
+     * user's nearest neighbors for kNN > 0
+     */
+    protected MathVector[] userNeighbors;
+
+    protected SparseVector[] userVectors;
+
+    protected SparseVector[] itemVectors;
+
+    private Comparator<Integer2FloatKeyValue> comparator = new Comparator<Integer2FloatKeyValue>() {
+
+        @Override
+        public int compare(Integer2FloatKeyValue left, Integer2FloatKeyValue right) {
+            int compare = -(Float.compare(left.getValue(), right.getValue()));
+            if (compare == 0) {
+                compare = Integer.compare(left.getKey(), right.getKey());
+            }
+            return compare;
+        }
+
+    };
+
+    protected MathVector getNeighborVector(Collection<Integer2FloatKeyValue> neighbors) {
+        int size = neighbors.size();
+        int[] indexes = new int[size];
+        float[] values = new float[size];
+        Int2FloatSortedMap keyValues = new Int2FloatRBTreeMap();
+        for (Integer2FloatKeyValue term : neighbors) {
+            keyValues.put(term.getKey(), term.getValue());
+        }
+        int cursor = 0;
+        for (Int2FloatMap.Entry term : keyValues.int2FloatEntrySet()) {
+            indexes[cursor] = term.getIntKey();
+            values[cursor] = term.getFloatValue();
+            cursor++;
+        }
+        return new ArrayVector(size, indexes, values);
+    }
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        neighborSize = configuration.getInteger("recommender.neighbors.knn.number");
+        // TODO set the capacity
+        userNeighbors = new MathVector[userSize];
+        Neighborhood<Integer2FloatKeyValue>[] knns = new Neighborhood[userSize];
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            knns[userIndex] = new Neighborhood<>(neighborSize, comparator);
+        }
+        // TODO change to a configurable enumeration
+        try {
+            Class<MathCorrelation> correlationClass = (Class<MathCorrelation>) Class.forName(configuration.getString("recommender.correlation.class"));
+            MathCorrelation correlation = ReflectionUtility.getInstance(correlationClass);
+            correlation.calculateCoefficients(scoreMatrix, false, (leftIndex, rightIndex, coefficient) -> {
+                if (leftIndex == rightIndex) {
+                    return;
+                }
+                // skip users whose similarity is 0
+                if (coefficient == 0F) {
+                    return;
+                }
+                knns[leftIndex].updateNeighbor(new Integer2FloatKeyValue(rightIndex, coefficient));
+                knns[rightIndex].updateNeighbor(new Integer2FloatKeyValue(leftIndex, coefficient));
+            });
+        } catch (Exception exception) {
+            throw new RuntimeException(exception);
+        }
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            userNeighbors[userIndex] = getNeighborVector(knns[userIndex].getNeighbors());
+        }
+
+        userMeans = DenseVector.valueOf(userSize);
+
+        userVectors = new SparseVector[userSize];
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            userVectors[userIndex] = scoreMatrix.getRowVector(userIndex);
+        }
+
+        itemVectors = new SparseVector[itemSize];
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            itemVectors[itemIndex] = scoreMatrix.getColumnVector(itemIndex);
+        }
+    }
+
+    @Override
+    protected void doPractice() {
+        meanScore = scoreMatrix.getSum(false) / scoreMatrix.getElementSize();
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+            userMeans.setValue(userIndex, userVector.getElementSize() > 0 ? userVector.getSum(false) / userVector.getElementSize() : meanScore);
+        }
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/collaborative.txt b/src/main/java/com/jstarcraft/rns/model/collaborative/collaborative.txt
new file mode 100644
index 0000000..6ad70a3
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/collaborative.txt
@@ -0,0 +1,2 @@
+A summary of collaborative filtering algorithms:
+http://www.cnblogs.com/pinard/p/6349233.html
\ No newline at end of file
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/AoBPRModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/AoBPRModel.java
new file mode 100644
index 0000000..fa263eb
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/AoBPRModel.java
@@ -0,0 +1,185 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.KeyValue;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+import com.jstarcraft.rns.utility.LogisticUtility;
+import com.jstarcraft.rns.utility.SampleUtility;
+
+import it.unimi.dsi.fastutil.ints.IntSet;
+
+/**
+ *
+ * AoBPR recommender
+ *
+ * AoBPR: BPR with Adaptive Oversampling
+ * Improving pairwise learning for item recommendation from implicit feedback
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class AoBPRModel extends MatrixFactorizationModel {
+    private int loopNumber;
+
+    /**
+     * item geometric distribution parameter
+     */
+    private int lambdaItem;
+
+    // TODO consider replacing with a matrix and a vector
+    private float[] factorVariances;
+    private int[][] factorRanks;
+    private DenseVector rankProbabilities;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        // settings specific to this algorithm
+        lambdaItem = (int) (configuration.getFloat("recommender.item.distribution.parameter") * itemSize);
+        // lambdaItem = 500;
+        loopNumber = (int) (itemSize * Math.log(itemSize));
+
+        factorVariances = new float[factorSize];
+        factorRanks = new int[factorSize][itemSize];
+    }
+
+    @Override
+    protected void doPractice() {
+        // sorted list of items
+        List<KeyValue<Integer, Float>> sortList = new ArrayList<>(itemSize);
+        DefaultScalar sum = DefaultScalar.getInstance();
+        sum.setValue(0F);
+        rankProbabilities = DenseVector.valueOf(itemSize);
+        rankProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            int index = scalar.getIndex();
+            sortList.add(new KeyValue<>(index, 0F));
+            float value = (float) Math.exp(-(index + 1F) / lambdaItem);
+            sum.shiftValue(value);
+            scalar.setValue(sum.getValue());
+        });
+        List<IntSet> userItemSet = getUserItemSet(scoreMatrix);
+
+        // TODO needs refactoring
+        List<Integer> userIndexes = new ArrayList<>(actionSize), itemIndexes = new ArrayList<>(actionSize);
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            userIndexes.add(userIndex);
+            itemIndexes.add(itemIndex);
+        }
+
+        // randomly draw a factor f by p(f|c)
+        DenseVector factorProbabilities = DenseVector.valueOf(factorSize);
+
+        int sampleCount = 0;
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            for (int sampleIndex = 0, sampleTimes = userSize * 100; sampleIndex < sampleTimes; sampleIndex++) {
+                // update the ranking every |I|log|I| samples
+                if (sampleCount % loopNumber == 0) {
+                    updateSortListByFactor(sortList);
+                    sampleCount = 0;
+                }
+                sampleCount++;
+
+                // randomly draw (u, i, j)
+                int userIndex, positiveItemIndex, negativeItemIndex;
+                while (true) {
+                    int random = RandomUtility.randomInteger(actionSize);
+                    userIndex = userIndexes.get(random);
+                    IntSet itemSet = userItemSet.get(userIndex);
+                    if (itemSet.size() == 0 || itemSet.size() == itemSize) {
+                        continue;
+                    }
+                    positiveItemIndex = itemIndexes.get(random);
+                    // compute the probabilities
+                    DenseVector factorVector = userFactors.getRowVector(userIndex);
+                    sum.setValue(0F);
+                    factorProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                        int index = scalar.getIndex();
+                        float value = Math.abs(factorVector.getValue(index)) * factorVariances[index];
+                        sum.shiftValue(value);
+                        scalar.setValue(sum.getValue());
+                    });
+                    do {
+                        // randomly draw a rank r by exp(-r/lambda)
+                        int rankIndex = SampleUtility.binarySearch(rankProbabilities, 0, rankProbabilities.getElementSize() - 1, RandomUtility.randomFloat(rankProbabilities.getValue(rankProbabilities.getElementSize() - 1)));
+                        int factorIndex = SampleUtility.binarySearch(factorProbabilities, 0, factorProbabilities.getElementSize() - 1, RandomUtility.randomFloat(factorProbabilities.getValue(factorProbabilities.getElementSize() - 1)));
+                        // take the item at rank r in factor f
+                        if (userFactors.getValue(userIndex, factorIndex) > 0) {
+                            negativeItemIndex = factorRanks[factorIndex][rankIndex];
+                        } else {
+                            negativeItemIndex = factorRanks[factorIndex][itemSize - rankIndex - 1];
+                        }
+                    } while (itemSet.contains(negativeItemIndex));
+                    break;
+                }
+
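+                /*
+                 * Adaptive oversampling in brief: a rank r is drawn with
+                 * probability proportional to exp(-r / lambda) and a factor f
+                 * with probability proportional to |p_uf| times the sampled
+                 * variance of that factor over items, so the negative item j
+                 * tends to be one the current model already scores highly,
+                 * which yields larger gradients than uniform sampling.
+                 */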
+                // update the parameters
+                float positiveScore = predict(userIndex, positiveItemIndex);
+                float negativeScore = predict(userIndex, negativeItemIndex);
+                float error = positiveScore - negativeScore;
+                float value = (float) -Math.log(LogisticUtility.getValue(error));
+                totalError += value;
+                value = LogisticUtility.getValue(-error);
+
+                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                    float userFactor = userFactors.getValue(userIndex, factorIndex);
+                    float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
+                    float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
+                    userFactors.shiftValue(userIndex, factorIndex, learnRatio * (value * (positiveFactor - negativeFactor) - userRegularization * userFactor));
+                    itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (value * userFactor - itemRegularization * positiveFactor));
+                    itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (value * (-userFactor) - itemRegularization * negativeFactor));
+                    totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveFactor * positiveFactor + itemRegularization * negativeFactor * negativeFactor;
+                }
+            }
+
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
+
+    // TODO consider refactoring
+    private void updateSortListByFactor(List<KeyValue<Integer, Float>> sortList) {
+        // iterate over each factor
+        for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+            float sum = 0F;
+            DenseVector factorVector = itemFactors.getColumnVector(factorIndex);
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                float value = factorVector.getValue(itemIndex);
+                sortList.get(itemIndex).setValue(value);
+                sum += value;
+            }
+            Collections.sort(sortList, (left, right) -> {
+                // descending order
+                return right.getValue().compareTo(left.getValue());
+            });
+            float mean = sum / factorVector.getElementSize();
+            sum = 0F;
+            for (int sortIndex = 0; sortIndex < itemSize; sortIndex++) {
+                float value = factorVector.getValue(sortIndex);
+                sum += (value - mean) * (value - mean);
+                factorRanks[factorIndex][sortIndex] = sortList.get(sortIndex).getKey();
+            }
+            factorVariances[factorIndex] = sum / factorVector.getElementSize();
+        }
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/AspectModelRankingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/AspectModelRankingModel.java
new file mode 100644
index 0000000..237f17e
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/AspectModelRankingModel.java
@@ -0,0 +1,135 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+
+/**
+ *
+ * Aspect Model recommender
+ *
+ * Latent class models for collaborative filtering
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class AspectModelRankingModel extends ProbabilisticGraphicalModel {
+
+    /**
+     * Conditional distribution: P(u|z)
+     */
+    private DenseMatrix userProbabilities, userSums;
+
+    /**
+     * Conditional distribution: P(i|z)
+     */
+    private DenseMatrix itemProbabilities, itemSums;
+
+    /**
+     * topic distribution: P(z)
+     */
+    private DenseVector topicProbabilities, topicSums;
+
+    private DenseVector probabilities;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+
+        // initialize the topic distribution
+        // TODO consider refactoring
+        topicProbabilities = DenseVector.valueOf(factorSize);
+        topicSums = DenseVector.valueOf(factorSize);
+        topicProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(RandomUtility.randomInteger(factorSize) + 1);
+        });
+        topicProbabilities.scaleValues(1F / topicProbabilities.getSum(false));
+
+        userProbabilities = DenseMatrix.valueOf(factorSize, userSize);
+        userSums = DenseMatrix.valueOf(factorSize, userSize);
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            DenseVector probabilityVector = userProbabilities.getRowVector(topicIndex);
+            probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                scalar.setValue(RandomUtility.randomInteger(userSize) + 1);
+            });
+            probabilityVector.scaleValues(1F / probabilityVector.getSum(false));
+        }
+
+        itemProbabilities = DenseMatrix.valueOf(factorSize, itemSize);
+        itemSums = DenseMatrix.valueOf(factorSize, itemSize);
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            DenseVector probabilityVector = itemProbabilities.getRowVector(topicIndex);
+            probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                scalar.setValue(RandomUtility.randomInteger(itemSize) + 1);
+            });
+            probabilityVector.scaleValues(1F / probabilityVector.getSum(false));
+        }
+
+        probabilities = DenseVector.valueOf(factorSize);
+    }
+
+    @Override
+    protected void eStep() {
+        topicSums.setValues(0F);
+        userSums.setValues(0F);
+        itemSums.setValues(0F);
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            probabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                int index = scalar.getIndex();
+                float value = userProbabilities.getValue(index, userIndex) * itemProbabilities.getValue(index, itemIndex) * topicProbabilities.getValue(index);
+                scalar.setValue(value);
+            });
+            probabilities.scaleValues(1F / probabilities.getSum(false));
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                float value = probabilities.getValue(topicIndex) * term.getValue();
+                topicSums.shiftValue(topicIndex, value);
+                userSums.shiftValue(topicIndex, userIndex, value);
+                itemSums.shiftValue(topicIndex, itemIndex, value);
+            }
+        }
+    }
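+    /*
+     * E step in brief: for every observed (u, i) the posterior over the latent
+     * class z is P(z | u, i) ∝ P(u | z) * P(i | z) * P(z); the normalized
+     * posterior, weighted by the rating value, is accumulated into topicSums,
+     * userSums and itemSums for the M step below.
+     */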
+    @Override
+    protected void mStep() {
+        float scale = 1F / topicSums.getSum(false);
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            topicProbabilities.setValue(topicIndex, topicSums.getValue(topicIndex) * scale);
+            float userSum = userSums.getRowVector(topicIndex).getSum(false);
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                userProbabilities.setValue(topicIndex, userIndex, userSums.getValue(topicIndex, userIndex) / userSum);
+            }
+            float itemSum = itemSums.getRowVector(topicIndex).getSum(false);
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                itemProbabilities.setValue(topicIndex, itemIndex, itemSums.getValue(topicIndex, itemIndex) / itemSum);
+            }
+        }
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        float value = 0F;
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            value += userProbabilities.getValue(topicIndex, userIndex) * itemProbabilities.getValue(topicIndex, itemIndex) * topicProbabilities.getValue(topicIndex);
+        }
+        instance.setQuantityMark(value);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BHFreeRankingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BHFreeRankingModel.java
new file mode 100644
index 0000000..169142f
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BHFreeRankingModel.java
@@ -0,0 +1,39 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Map.Entry;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.rns.model.collaborative.BHFreeModel;
+
+/**
+ *
+ * BH Free recommender
+ *
+ * Balancing Prediction and Recommendation Accuracy: Hierarchical Latent Factors for Preference Data
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class BHFreeRankingModel extends BHFreeModel {
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        float value = 0F;
+        for (Entry<Float, Integer> entry : scoreIndexes.entrySet()) {
+            float score = entry.getKey();
+            float probability = 0F;
+            for (int userTopic = 0; userTopic < userTopicSize; userTopic++) {
+                for (int itemTopic = 0; itemTopic < itemTopicSize; itemTopic++) {
+                    probability += user2TopicProbabilities.getValue(userIndex, userTopic) * userTopic2ItemTopicProbabilities.getValue(userTopic, itemTopic) * userTopic2ItemTopicItemSums[userTopic][itemTopic][itemIndex] * userTopic2ItemTopicScoreProbabilities[userTopic][itemTopic][entry.getValue()];
+                }
+            }
+            value += score * probability;
+        }
+        instance.setQuantityMark(value);
+    }
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BPRModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BPRModel.java
new file mode 100644
index 0000000..7b1f386
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BPRModel.java
@@ -0,0 +1,74 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+import com.jstarcraft.rns.utility.LogisticUtility;
+
+/**
+ *
+ * BPR recommender
+ *
+ * BPR: Bayesian Personalized Ranking from Implicit Feedback
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class BPRModel extends MatrixFactorizationModel {
+
+    @Override
+    protected void doPractice() {
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            for (int sampleIndex = 0, sampleTimes = userSize * 100; sampleIndex < sampleTimes; sampleIndex++) {
+                // randomly draw (userIdx, posItemIdx, negItemIdx)
+                int userIndex, positiveItemIndex, negativeItemIndex;
+                while (true) {
+                    userIndex = RandomUtility.randomInteger(userSize);
+                    SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+                    if (userVector.getElementSize() == 0) {
+                        continue;
+                    }
+                    positiveItemIndex = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize()));
+                    // draw from the unrated range, then shift past rated indexes
+                    negativeItemIndex = RandomUtility.randomInteger(itemSize - userVector.getElementSize());
+                    for (VectorScalar term : userVector) {
+                        if (negativeItemIndex >= term.getIndex()) {
+                            negativeItemIndex++;
+                        } else {
+                            break;
+                        }
+                    }
+                    break;
+                }
+
+                // update the parameters
+                float positiveScore = predict(userIndex, positiveItemIndex);
+                float negativeScore = predict(userIndex, negativeItemIndex);
+                float error = positiveScore - negativeScore;
+                float value = (float) -Math.log(LogisticUtility.getValue(error));
+                totalError += value;
+                value = LogisticUtility.getValue(-error);
+
+                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                    float userFactor = userFactors.getValue(userIndex, factorIndex);
+                    float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
+                    float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
+                    userFactors.shiftValue(userIndex, factorIndex, learnRatio * (value * (positiveFactor - negativeFactor) - userRegularization * userFactor));
+                    itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (value * userFactor - itemRegularization * positiveFactor));
+                    itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (value * (-userFactor) - itemRegularization * negativeFactor));
+                    totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveFactor * positiveFactor + itemRegularization * negativeFactor * negativeFactor;
+                }
+            }
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
+
+}
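+// The quantity optimized above is the BPR criterion: maximize
+// ln sigma(x_ui - x_uj) - lambda * ||Theta||^2 over sampled triples (u, i, j);
+// differentiating it with respect to each factor gives exactly the three
+// shiftValue updates in doPractice.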
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BUCMRankingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BUCMRankingModel.java
new file mode 100644
index 0000000..47a71d3
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/BUCMRankingModel.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Map.Entry;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.rns.model.collaborative.BUCMModel;
+
+/**
+ *
+ * BUCM recommender
+ *
+ * Bayesian User Community Model
+ * Modeling Item Selection and Relevance for Accurate Recommendations
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class BUCMRankingModel extends BUCMModel {
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        float value = 0F;
+        for (int topicIndex = 0; topicIndex < factorSize; ++topicIndex) {
+            float sum = 0F;
+            for (Entry<Float, Integer> term : scoreIndexes.entrySet()) {
+                double score = term.getKey();
+                if (score > meanScore) {
+                    sum += topicItemScoreProbabilities[topicIndex][itemIndex][term.getValue()];
+                }
+            }
+            value += userTopicProbabilities.getValue(userIndex, topicIndex) * topicItemProbabilities.getValue(topicIndex, itemIndex) * sum;
+        }
+        instance.setQuantityMark(value);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/CLiMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/CLiMFModel.java
new file mode 100644
index 0000000..5b2a90f
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/CLiMFModel.java
@@ -0,0 +1,139 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.HashMap;
+import java.util.List;
+
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+import com.jstarcraft.rns.utility.LogisticUtility;
+
+import it.unimi.dsi.fastutil.ints.IntSet;
+
+/**
+ *
+ * CLiMF recommender
+ *
+ * CLiMF: learning to maximize reciprocal rank with collaborative less-is-more filtering
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class CLiMFModel extends MatrixFactorizationModel {
+
+    @Override
+    protected void doPractice() {
+        List<IntSet> userItemSet = getUserItemSet(scoreMatrix);
+
+        float[] factorValues = new float[factorSize];
+
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                // TODO should be refactored so that itemSet is no longer needed
+                IntSet itemSet = userItemSet.get(userIndex);
+
+                // cache the predictions
+                DenseVector predictVector = DenseVector.valueOf(itemSet.size());
+                DenseVector logisticVector = DenseVector.valueOf(itemSet.size());
+                int index = 0;
+                for (int itemIndex : itemSet) {
+                    float value = predict(userIndex, itemIndex);
+                    predictVector.setValue(index, value);
+                    logisticVector.setValue(index, LogisticUtility.getValue(-value));
+                    index++;
+                }
+                DenseMatrix logisticMatrix = DenseMatrix.valueOf(itemSet.size(), itemSet.size());
+                DenseMatrix gradientMatrix = DenseMatrix.valueOf(itemSet.size(), itemSet.size());
+                gradientMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                    int row = scalar.getRow();
+                    int column = scalar.getColumn();
+                    float value = predictVector.getValue(row) - predictVector.getValue(column);
+                    float logistic = LogisticUtility.getValue(value);
+                    logisticMatrix.setValue(row, column, logistic);
+                    float gradient = LogisticUtility.getGradient(value);
+                    scalar.setValue(gradient);
+                });
+
+                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                    float factorValue = -userRegularization * userFactors.getValue(userIndex, factorIndex);
+                    int leftIndex = 0;
+                    for (int itemIndex : itemSet) {
+                        float itemFactorValue = itemFactors.getValue(itemIndex, factorIndex);
+                        factorValue += logisticVector.getValue(leftIndex) * itemFactorValue;
+                        // TODO symmetry could be exploited to halve the iterations
+                        int rightIndex = 0;
+                        for (int compareIndex : itemSet) {
+                            if (compareIndex != itemIndex) {
+                                float compareValue = itemFactors.getValue(compareIndex, factorIndex);
+                                factorValue += gradientMatrix.getValue(rightIndex, leftIndex) / (1 - logisticMatrix.getValue(rightIndex, leftIndex)) * (itemFactorValue - compareValue);
+                            }
+                            rightIndex++;
+                        }
+                        leftIndex++;
+                    }
+                    factorValues[factorIndex] = factorValue;
+                }
+
+                int leftIndex = 0;
+                for (int itemIndex : itemSet) {
+                    float logisticValue = logisticVector.getValue(leftIndex);
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float userFactorValue = userFactors.getValue(userIndex, factorIndex);
+                        float itemFactorValue = itemFactors.getValue(itemIndex, factorIndex);
+                        float judgeValue = 1F;
+                        float factorValue = judgeValue * logisticValue * userFactorValue - itemRegularization * itemFactorValue;
+                        // TODO symmetry could be exploited to halve the iterations
+                        int rightIndex = 0;
+                        for (int compareIndex : itemSet) {
+                            if (compareIndex != itemIndex) {
+                                factorValue += gradientMatrix.getValue(rightIndex, leftIndex) * (judgeValue / (judgeValue - logisticMatrix.getValue(rightIndex, leftIndex)) - judgeValue / (judgeValue - logisticMatrix.getValue(leftIndex, rightIndex))) * userFactorValue;
+                            }
+                            rightIndex++;
+                        }
+                        itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * factorValue);
+                    }
+                    leftIndex++;
+                }
+
+                for (int factorIdx = 0; factorIdx < factorSize; factorIdx++) {
+                    userFactors.shiftValue(userIndex, factorIdx, learnRatio * factorValues[factorIdx]);
+                }
+
+                // TODO get the prediction values
+                HashMap<Integer, Float> predictMap = new HashMap<>(itemSet.size());
+                for (int itemIndex : itemSet) {
+                    float predictValue = predict(userIndex, itemIndex);
+                    predictMap.put(itemIndex, predictValue);
+                }
+                for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                    if (itemSet.contains(itemIndex)) {
+                        float predictValue = predictMap.get(itemIndex);
+                        totalError += (float) Math.log(LogisticUtility.getValue(predictValue));
+                        // TODO symmetry could be exploited to halve the iterations
+                        for (int compareIndex : itemSet) {
+                            float compareValue = predictMap.get(compareIndex);
+                            totalError += (float) Math.log(1 - LogisticUtility.getValue(compareValue - predictValue));
+                        }
+                    }
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float userFactorValue = userFactors.getValue(userIndex, factorIndex);
+                        float itemFactorValue = itemFactors.getValue(itemIndex, factorIndex);
+                        totalError += -0.5 * (userRegularization * userFactorValue * userFactorValue + itemRegularization * itemFactorValue * itemFactorValue);
+                    }
+                }
+            }
+
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
+
+}
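+// What doPractice ascends (a sketch): the smoothed lower bound on reciprocal
+// rank from the CLiMF paper, Σ_i [ln sigma(f_ui) + Σ_j ln(1 - sigma(f_uj - f_ui))]
+// over each user's relevant items, minus the L2 penalties; the two nested
+// loops over itemSet compute the gradients of those two terms.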
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/EALSModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/EALSModel.java
new file mode 100644
index 0000000..a604df3
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/EALSModel.java
@@ -0,0 +1,237 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.concurrent.CountDownLatch;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.environment.EnvironmentContext;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.SparseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+import com.jstarcraft.rns.model.exception.ModelException;
+
+/**
+ *
+ * EALS recommender
+ *
+ * EALS: efficient Alternating Least Square for Weighted Regularized Matrix Factorization
+ * Collaborative filtering for implicit feedback dataset
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class EALSModel extends MatrixFactorizationModel {
+    /**
+     * confidence weight coefficient for WRMF
+     */
+    protected float weightCoefficient;
+
+    /**
+     * the significance level of popular items over unpopular ones
+     */
+    private float ratio;
+
+    /**
+     * the overall weight of missing data c0
+     */
+    private float overallWeight;
+
+    /**
+     * 0: eALS MF; 1: WRMF; 2: both
+     */
+    private int type;
+
+    /**
+     * confidence that item i is missed by users
+     */
+    private float[] confidences;
+
+    /**
+     * weights of all user-item pairs (u, i)
+     */
+    private SparseMatrix weights;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        weightCoefficient = configuration.getFloat("recommender.wrmf.weight.coefficient", 4.0f);
+        ratio = configuration.getFloat("recommender.eals.ratio", 0.4f);
+        overallWeight = configuration.getFloat("recommender.eals.overall", 128.0f);
+        type = configuration.getInteger("recommender.eals.wrmf.judge", 1);
+
+        confidences = new float[itemSize];
+
+        // get c_i
+        if (type == 0 || type == 2) {
+            float sumPopularity = 0F;
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                float alphaPopularity = (float) Math.pow(scoreMatrix.getColumnScope(itemIndex) * 1.0 / actionSize, ratio);
+                confidences[itemIndex] = overallWeight * alphaPopularity;
+                sumPopularity += alphaPopularity;
+            }
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                confidences[itemIndex] = confidences[itemIndex] / sumPopularity;
+            }
+        } else {
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                confidences[itemIndex] = 1;
+            }
+        }
+
+        weights = SparseMatrix.copyOf(scoreMatrix, false);
+        weights.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            if (type == 1 || type == 2) {
+                scalar.setValue(1F + weightCoefficient * scalar.getValue());
+            } else {
+                scalar.setValue(1F);
+            }
+        });
+    }
+
+    private ThreadLocal<float[]> itemScoreStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> itemWeightStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> userScoreStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> userWeightStorage = new ThreadLocal<>();
+
+    @Override
+    protected void constructEnvironment() {
+        // TODO the allocation could be smaller (use the longest vector of the sparse matrix as the cache size)
+        itemScoreStorage.set(new float[itemSize]);
+        itemWeightStorage.set(new float[itemSize]);
+        userScoreStorage.set(new float[userSize]);
+        userWeightStorage.set(new float[userSize]);
+    }
+
+    @Override
+    protected void destructEnvironment() {
+        itemScoreStorage.remove();
+        itemWeightStorage.remove();
+        userScoreStorage.remove();
+        userWeightStorage.remove();
+    }
+
+    @Override
+    protected void doPractice() {
+        EnvironmentContext context = EnvironmentContext.getContext();
+        DenseMatrix itemDeltas = DenseMatrix.valueOf(factorSize, factorSize);
+        DenseMatrix userDeltas = DenseMatrix.valueOf(factorSize, factorSize);
+
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            // update the Sq cache
+            for (int leftFactorIndex = 0; leftFactorIndex < factorSize; leftFactorIndex++) {
+                for (int rightFactorIndex = leftFactorIndex; rightFactorIndex < factorSize; rightFactorIndex++) {
+                    float value = 0F;
+                    for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                        value += confidences[itemIndex] * itemFactors.getValue(itemIndex, leftFactorIndex) * itemFactors.getValue(itemIndex, rightFactorIndex);
+                    }
+                    itemDeltas.setValue(leftFactorIndex, rightFactorIndex, value);
+                    itemDeltas.setValue(rightFactorIndex, leftFactorIndex, value);
+                }
+            }
+            // step 1: update the user factors;
+            // partition the work by user for concurrent computation
+            CountDownLatch userLatch = new CountDownLatch(userSize);
+            for (int index = 0; index < userSize; index++) {
+                int userIndex = index;
+                context.doAlgorithmByAny(index, () -> {
+                    DefaultScalar scalar = DefaultScalar.getInstance();
+                    SparseVector userVector = weights.getRowVector(userIndex);
+                    DenseVector factorVector = userFactors.getRowVector(userIndex);
+                    float[] itemScores = itemScoreStorage.get();
+                    float[] itemWeights = itemWeightStorage.get();
+                    for (VectorScalar term : userVector) {
+                        int itemIndex = term.getIndex();
+                        DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+                        itemScores[itemIndex] = scalar.dotProduct(itemVector, factorVector).getValue();
+                        itemWeights[itemIndex] = term.getValue();
+                    }
+                    for (int leftFactorIndex = 0; leftFactorIndex < factorSize; leftFactorIndex++) {
+                        float numerator = 0, denominator = userRegularization + itemDeltas.getValue(leftFactorIndex, leftFactorIndex);
+                        for (int rightFactorIndex = 0; rightFactorIndex < factorSize; rightFactorIndex++) {
+                            if (leftFactorIndex != rightFactorIndex) {
+                                numerator -= userFactors.getValue(userIndex, rightFactorIndex) * itemDeltas.getValue(leftFactorIndex, rightFactorIndex);
+                            }
+                        }
+                        for (VectorScalar term : userVector) {
+                            int itemIndex = term.getIndex();
+                            itemScores[itemIndex] -= userFactors.getValue(userIndex, leftFactorIndex) * itemFactors.getValue(itemIndex, leftFactorIndex);
+                            numerator += (itemWeights[itemIndex] - (itemWeights[itemIndex] - confidences[itemIndex]) * itemScores[itemIndex]) * itemFactors.getValue(itemIndex, leftFactorIndex);
+                            denominator += (itemWeights[itemIndex] - confidences[itemIndex]) * itemFactors.getValue(itemIndex, leftFactorIndex) * itemFactors.getValue(itemIndex, leftFactorIndex);
+                        }
+                        // update p_uf
+                        userFactors.setValue(userIndex, leftFactorIndex, numerator / denominator);
+                        for (VectorScalar term : userVector) {
+                            int itemIndex = term.getIndex();
+                            itemScores[itemIndex] += userFactors.getValue(userIndex, leftFactorIndex) * itemFactors.getValue(itemIndex, leftFactorIndex);
+                        }
+                    }
+                    userLatch.countDown();
+                });
+            }
+            try {
+                userLatch.await();
+            } catch (Exception exception) {
+                throw new ModelException(exception);
+            }
+
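+            /*
+             * Element-wise ALS in brief: with the K x K cache S^q =
+             * Σ_i c_i * q_i q_i^T, each coordinate p_uf is solved in closed
+             * form from a numerator/denominator pair, and the cached
+             * prediction is patched incrementally instead of being
+             * recomputed, dropping the per-user cost from O(K^3) to roughly
+             * O(K^2 + |R_u| * K).
+             */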
+            // update the Sp cache
+            userDeltas.dotProduct(userFactors, true, userFactors, false, MathCalculator.SERIAL);
+            // step 2: update the item factors;
+            // partition the work by item for concurrent computation
+            CountDownLatch itemLatch = new CountDownLatch(itemSize);
+            for (int index = 0; index < itemSize; index++) {
+                int itemIndex = index;
+                context.doAlgorithmByAny(index, () -> {
+                    DefaultScalar scalar = DefaultScalar.getInstance();
+                    SparseVector itemVector = weights.getColumnVector(itemIndex);
+                    DenseVector factorVector = itemFactors.getRowVector(itemIndex);
+                    float[] userScores = userScoreStorage.get();
+                    float[] userWeights = userWeightStorage.get();
+                    for (VectorScalar term : itemVector) {
+                        int userIndex = term.getIndex();
+                        DenseVector userVector = userFactors.getRowVector(userIndex);
+                        userScores[userIndex] = scalar.dotProduct(userVector, factorVector).getValue();
+                        userWeights[userIndex] = term.getValue();
+                    }
+                    for (int leftFactorIndex = 0; leftFactorIndex < factorSize; leftFactorIndex++) {
+                        float numerator = 0, denominator = confidences[itemIndex] * userDeltas.getValue(leftFactorIndex, leftFactorIndex) + itemRegularization;
+                        for (int rightFactorIndex = 0; rightFactorIndex < factorSize; rightFactorIndex++) {
+                            if (leftFactorIndex != rightFactorIndex) {
+                                numerator -= itemFactors.getValue(itemIndex, rightFactorIndex) * userDeltas.getValue(rightFactorIndex, leftFactorIndex);
+                            }
+                        }
+                        numerator *= confidences[itemIndex];
+                        for (VectorScalar term : itemVector) {
+                            int userIndex = term.getIndex();
+                            userScores[userIndex] -= userFactors.getValue(userIndex, leftFactorIndex) * itemFactors.getValue(itemIndex, leftFactorIndex);
+                            numerator += (userWeights[userIndex] - (userWeights[userIndex] - confidences[itemIndex]) * userScores[userIndex]) * userFactors.getValue(userIndex, leftFactorIndex);
+                            denominator += (userWeights[userIndex] - confidences[itemIndex]) * userFactors.getValue(userIndex, leftFactorIndex) * userFactors.getValue(userIndex, leftFactorIndex);
+                        }
+                        // update q_if
+                        itemFactors.setValue(itemIndex, leftFactorIndex, numerator / denominator);
+                        for (VectorScalar term : itemVector) {
+                            int userIndex = term.getIndex();
+                            userScores[userIndex] += userFactors.getValue(userIndex, leftFactorIndex) * itemFactors.getValue(itemIndex, leftFactorIndex);
+                        }
+                    }
+                    itemLatch.countDown();
+                });
+            }
+            try {
+                itemLatch.await();
+            } catch (Exception exception) {
+                throw new ModelException(exception);
+            }
+        }
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/FISMAUCModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/FISMAUCModel.java
new file mode 100644
index 0000000..1012b7d
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/FISMAUCModel.java
@@ -0,0 +1,245 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+
+/**
+ *
+ * FISM-AUC recommender
+ *
+ * FISM: Factored Item Similarity Models for Top-N Recommender Systems
+ * Based on the LibRec implementation
+ * 
+ *
+ * @author Birdy
+ *
+ */
+// note: FISM composes userFactors from itemFactors
+public class FISMAUCModel extends MatrixFactorizationModel {
+
+    private float rho, alpha, beta, gamma;
+
+    /**
+     * step size applied to all SGD updates
+     */
+    private float biasRegularization;
+
+    /**
+     * item biases vector
+     */
+    private DenseVector itemBiases;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        // note: FISM composes userFactors from itemFactors
+        userFactors = DenseMatrix.valueOf(itemSize, factorSize);
+        userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(distribution.sample().floatValue());
+        });
+        itemFactors = DenseMatrix.valueOf(itemSize, factorSize);
+        itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(distribution.sample().floatValue());
+        });
+        // TODO
+        itemBiases = DenseVector.valueOf(itemSize);
+        itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(distribution.sample().floatValue());
+        });
+        rho = configuration.getFloat("recommender.fismauc.rho"); // 3-15
+        alpha = configuration.getFloat("recommender.fismauc.alpha", 0.5F);
+        beta = configuration.getFloat("recommender.fismauc.beta", 0.6F);
+        gamma = configuration.getFloat("recommender.fismauc.gamma", 0.1F);
+        biasRegularization = configuration.getFloat("recommender.iteration.learnrate", 0.0001F);
+        // cacheSpec = conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
+    }
+
+    @Override
+    protected void doPractice() {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        // x <- 0
+        DenseVector userVector = DenseVector.valueOf(factorSize);
+        // t <- (n - 1)^(-alpha) * Σ p_j (j != i)
+        DenseVector itemVector = DenseVector.valueOf(factorSize);
+
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            // for all u in C
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                SparseVector rateVector = scoreMatrix.getRowVector(userIndex);
+                int size = rateVector.getElementSize();
+                if (size == 0 || size == 1) {
+                    size = 2;
+                }
+                // for all i in Ru+
+                for (VectorScalar positiveTerm : rateVector) {
+                    int positiveIndex = positiveTerm.getIndex();
+                    userVector.setValues(0F);
+                    itemVector.setValues(0F);
+                    for (VectorScalar negativeTerm : rateVector) {
+                        int negativeIndex = negativeTerm.getIndex();
+                        if (positiveIndex != negativeIndex) {
+                            itemVector.addVector(userFactors.getRowVector(negativeIndex));
+                        }
+                    }
+                    itemVector.scaleValues((float) Math.pow(size - 1, -alpha));
+                    // Z <- SampleZeros(rho)
+                    int sampleSize = (int) (rho * size);
+                    // make a random sample of negative feedback for Ru-
+                    List<Integer> negativeIndexes = new LinkedList<>();
+                    for (int sampleIndex = 0; sampleIndex < sampleSize; sampleIndex++) {
+                        int negativeItemIndex = RandomUtility.randomInteger(itemSize - negativeIndexes.size());
+                        int index = 0;
+                        for (int negativeIndex : negativeIndexes) {
+                            if (negativeItemIndex >= negativeIndex) {
+                                negativeItemIndex++;
+                                index++;
+                            } else {
+                                break;
+                            }
+                        }
+                        negativeIndexes.add(index, negativeItemIndex);
+                    }
+
+                    int leftCursor = 0, rightCursor = 0, leftSize = rateVector.getElementSize(), rightSize = sampleSize;
+                    if (leftSize != 0 && rightSize != 0) {
+                        Iterator<VectorScalar> leftIterator = rateVector.iterator();
+                        Iterator<Integer> rightIterator = negativeIndexes.iterator();
+                        VectorScalar leftTerm = leftIterator.next();
+                        int negativeItemIndex = rightIterator.next();
+                        // drop sampled indexes that also appear in the rated set (both sequences are sorted)
+                        while (leftCursor < leftSize && rightCursor < rightSize) {
+                            if (leftTerm.getIndex() == negativeItemIndex) {
+                                if (leftIterator.hasNext()) {
+                                    leftTerm = leftIterator.next();
+                                }
+                                rightIterator.remove();
+                                if (rightIterator.hasNext()) {
+                                    negativeItemIndex = rightIterator.next();
+                                }
+                                leftCursor++;
+                                rightCursor++;
+                            } else if (leftTerm.getIndex() > negativeItemIndex) {
+                                if (rightIterator.hasNext()) {
+                                    negativeItemIndex = rightIterator.next();
+                                }
+                                rightCursor++;
+                            } else if (leftTerm.getIndex() < negativeItemIndex) {
+                                if (leftIterator.hasNext()) {
+                                    leftTerm = leftIterator.next();
+                                }
+                                leftCursor++;
+                            }
+                        }
+                    }
+
+                    // for all j in Z
+                    for (int negativeIndex : negativeIndexes) {
+                        // update pui puj rui ruj
+                        float positiveScore = positiveTerm.getValue();
+                        float negativeScore = 0F;
+                        float positiveBias = itemBiases.getValue(positiveIndex);
+                        float negativeBias = itemBiases.getValue(negativeIndex);
+                        float positiveFactor = positiveBias + scalar.dotProduct(itemFactors.getRowVector(positiveIndex), itemVector).getValue();
+                        float negativeFactor = negativeBias + scalar.dotProduct(itemFactors.getRowVector(negativeIndex), itemVector).getValue();
+
+                        float error = (positiveScore - negativeScore) - (positiveFactor - negativeFactor);
+                        totalError += error * error;
+
+                        // update bi bj
+                        itemBiases.shiftValue(positiveIndex, biasRegularization * (error - gamma * positiveBias));
+                        itemBiases.shiftValue(negativeIndex, biasRegularization * (error - gamma * negativeBias));
+
+                        // update qi qj
+                        DenseVector positiveVector = itemFactors.getRowVector(positiveIndex);
+                        positiveVector.iterateElement(MathCalculator.SERIAL, (element) -> {
+                            int index = element.getIndex();
+                            float value = element.getValue();
+                            element.setValue(value + (itemVector.getValue(index) * error - value * beta) * biasRegularization);
+                        });
+                        DenseVector negativeVector = itemFactors.getRowVector(negativeIndex);
+                        negativeVector.iterateElement(MathCalculator.SERIAL, (element) -> {
+                            int index = element.getIndex();
+                            float value = element.getValue();
+                            element.setValue(value - (itemVector.getValue(index) * error - value * beta) * biasRegularization);
+                        });
+
+                        // update x
+                        userVector.iterateElement(MathCalculator.SERIAL, (element) -> {
+                            int index = element.getIndex();
+                            float value = element.getValue();
+                            element.setValue(value + (positiveVector.getValue(index) - negativeVector.getValue(index)) * error);
+                        });
+                    }
+
+                    float scale = (float) (Math.pow(rho, -1) * Math.pow(size - 1, -alpha));
+
+                    // for all j in Ru+\{i}
+                    for (VectorScalar term : rateVector) {
+                        int negativeIndex = term.getIndex();
+                        if (negativeIndex != positiveIndex) {
+                            // update pj
+                            DenseVector negativeVector = userFactors.getRowVector(negativeIndex);
+                            negativeVector.iterateElement(MathCalculator.SERIAL, (element) -> {
+                                int index = element.getIndex();
+                                float value = element.getValue();
+                                element.setValue((userVector.getValue(index) * scale - value * beta) * biasRegularization + value);
+                            });
+                        }
+                    }
+                }
+            }
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                double itemBias = itemBiases.getValue(itemIndex);
+                totalError += gamma * itemBias * itemBias;
+                totalError += beta * scalar.dotProduct(itemFactors.getRowVector(itemIndex), itemFactors.getRowVector(itemIndex)).getValue();
+                totalError += beta * scalar.dotProduct(userFactors.getRowVector(itemIndex), userFactors.getRowVector(itemIndex)).getValue();
+            }
+
+            totalError *= 0.5F;
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            currentError = totalError;
+        }
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        float bias = itemBiases.getValue(itemIndex);
+        float sum = 0F;
+        int count = 0;
+        for (VectorScalar term : scoreMatrix.getRowVector(userIndex)) {
+            int compareIndex = term.getIndex();
+            // for test, i and j will be always unequal as j is unrated
+            if (compareIndex != itemIndex) {
+                DenseVector compareVector = userFactors.getRowVector(compareIndex);
+                DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+                sum += scalar.dotProduct(compareVector, itemVector).getValue();
+                count++;
+            }
+        }
+        sum *= (float) (count > 0 ? Math.pow(count, -alpha) : 0F);
+        instance.setQuantityMark(bias + sum);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/FISMRMSEModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/FISMRMSEModel.java
new file mode 100644
index 0000000..d02f37d
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/FISMRMSEModel.java
@@ -0,0 +1,202 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.HashMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+
+import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap;
+
+/**
+ *
+ * FISM-RMSE recommender
+ *
+ * <pre>
+ * FISM: Factored Item Similarity Models for Top-N Recommender Systems
+ * Based on the LibRec team's implementation
+ * </pre>
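+ *
+ * A configuration sketch (keys as read by prepare() below; the values are
+ * illustrative rather than tuned recommendations):
+ * <pre>
+ * recommender.fismrmse.rho=3
+ * recommender.fismrmse.alpha=0.5
+ * recommender.fismrmse.beta=0.6
+ * recommender.fismrmse.gamma=0.1
+ * recommender.fismrmse.lrate=0.0001
+ * </pre>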
+ *
+ * @author Birdy
+ *
+ */
+// Note: FISM composes userFactors from itemFactors.
+public class FISMRMSEModel extends MatrixFactorizationModel {
+
+    private int numNeighbors;
+
+    private float rho, alpha, beta, itemRegularization, userRegularization;
+
+    /**
+     * learning rate
+     */
+    private float learnRatio;
+
+    /**
+     * item and user bias vectors
+     */
+    private DenseVector itemBiases, userBiases;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        // Note: FISM composes userFactors from itemFactors.
+        userFactors = DenseMatrix.valueOf(itemSize, factorSize);
+        userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(distribution.sample().floatValue());
+        });
+        itemFactors = DenseMatrix.valueOf(itemSize, factorSize);
+        itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(distribution.sample().floatValue());
+        });
+        userBiases = DenseVector.valueOf(userSize);
+        userBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(distribution.sample().floatValue());
+        });
+        itemBiases = DenseVector.valueOf(itemSize);
+        itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(distribution.sample().floatValue());
+        });
+
+        numNeighbors = scoreMatrix.getElementSize();
+        rho = configuration.getFloat("recommender.fismrmse.rho");// 3-15
+        alpha = configuration.getFloat("recommender.fismrmse.alpha", 0.5F);
+        beta = configuration.getFloat("recommender.fismrmse.beta", 0.6F);
+        itemRegularization = configuration.getFloat("recommender.fismrmse.gamma", 0.1F);
+        userRegularization = configuration.getFloat("recommender.fismrmse.gamma", 0.1F);
+        learnRatio = configuration.getFloat("recommender.fismrmse.lrate", 0.0001F);
+    }
+
+    @Override
+    protected void doPractice() {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        int sampleSize = (int) (rho * numNeighbors);
+        int totalSize = userSize * itemSize;
+        HashMatrix rateMatrix = new HashMatrix(true, userSize, itemSize, new Long2FloatRBTreeMap());
+        for (MatrixScalar cell : scoreMatrix) {
+            rateMatrix.setValue(cell.getRow(), cell.getColumn(), cell.getValue());
+        }
+        int[] sampleIndexes = new int[sampleSize];
+
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            DenseVector userVector = DenseVector.valueOf(factorSize);
+            totalError = 0F;
+            // new training data by sampling negative values
+            // R augments trainMatrix with randomly sampled negative cells.
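+            // Commentary on the sampling below: each candidate negative cell is drawn
+            // as one linear index over the userSize * itemSize grid and decoded
+            // row-major (rowIndex = randomIndex / itemSize, columnIndex = randomIndex % itemSize);
+            // a cell is accepted only while rateMatrix still holds NaN there, which this
+            // code treats as the missing-feedback marker, and every sampled zero is
+            // reset to NaN at the end of the epoch.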
+ + // make a random sample of negative feedback (total - nnz) + for (int sampleIndex = 0; sampleIndex < sampleSize; sampleIndex++) { + while (true) { + int randomIndex = RandomUtility.randomInteger(totalSize - numNeighbors); + int rowIndex = randomIndex / itemSize; + int columnIndex = randomIndex % itemSize; + + if (Float.isNaN(rateMatrix.getValue(rowIndex, columnIndex))) { + sampleIndexes[sampleIndex] = randomIndex; + rateMatrix.setValue(rowIndex, columnIndex, 0F); + break; + } + } + } + + // update throughout each user-item-rating (u, i, rui) cell + for (MatrixScalar cell : rateMatrix) { + int userIndex = cell.getRow(); + int itemIndex = cell.getColumn(); + float score = cell.getValue(); + SparseVector rateVector = scoreMatrix.getRowVector(userIndex); + int size = rateVector.getElementSize() - 1; + if (size == 0 || size == -1) { + size = 1; + } + for (VectorScalar term : rateVector) { + int compareIndex = term.getIndex(); + if (itemIndex != compareIndex) { + userVector.addVector(userFactors.getRowVector(compareIndex)); + } + } + userVector.scaleValues((float) Math.pow(size, -alpha)); + // for efficiency, use the below code to predict rui instead of + // simply using "predict(u,j)" + float itemBias = itemBiases.getValue(itemIndex); + float userBias = userBiases.getValue(userIndex); + float predict = itemBias + userBias + scalar.dotProduct(itemFactors.getRowVector(itemIndex), userVector).getValue(); + float error = score - predict; + totalError += error * error; + // update bi + itemBiases.shiftValue(itemIndex, learnRatio * (error - itemRegularization * itemBias)); + totalError += itemRegularization * itemBias * itemBias; + // update bu + userBiases.shiftValue(userIndex, learnRatio * (error - userRegularization * userBias)); + totalError += userRegularization * userBias * userBias; + + DenseVector factorVector = itemFactors.getRowVector(itemIndex); + factorVector.iterateElement(MathCalculator.SERIAL, (element) -> { + int index = element.getIndex(); + float value = element.getValue(); + element.setValue((userVector.getValue(index) * error - value * beta) * learnRatio + value); + }); + totalError += beta * scalar.dotProduct(factorVector, factorVector).getValue(); + + for (VectorScalar term : rateVector) { + int compareIndex = term.getIndex(); + if (itemIndex != compareIndex) { + float scale = (float) (error * Math.pow(size, -alpha)); + factorVector = userFactors.getRowVector(compareIndex); + factorVector.iterateElement(MathCalculator.SERIAL, (element) -> { + int index = element.getIndex(); + float value = element.getValue(); + element.setValue((value * scale - value * beta) * learnRatio + value); + }); + totalError += beta * scalar.dotProduct(factorVector, factorVector).getValue(); + } + } + } + + for (int sampleIndex : sampleIndexes) { + int rowIndex = sampleIndex / itemSize; + int columnIndex = sampleIndex % itemSize; + rateMatrix.setValue(rowIndex, columnIndex, Float.NaN); + } + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + float bias = userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex); + float sum = 0F; + int count = 0; + for (VectorScalar term : scoreMatrix.getRowVector(userIndex)) { + int index = term.getIndex(); + // for test, i and j will be always unequal as 
j is unrated
+            if (index != itemIndex) {
+                DenseVector userVector = userFactors.getRowVector(index);
+                DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+                sum += scalar.dotProduct(userVector, itemVector).getValue();
+                count++;
+            }
+        }
+        sum *= (float) (count > 0 ? Math.pow(count, -alpha) : 0F);
+        instance.setQuantityMark(bias + sum);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/GBPRModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/GBPRModel.java
new file mode 100644
index 0000000..cb2c3cf
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/GBPRModel.java
@@ -0,0 +1,177 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+import com.jstarcraft.rns.utility.LogisticUtility;
+
+/**
+ *
+ * GBPR recommender
+ *
+ * <pre>
+ * GBPR: Group Preference Based Bayesian Personalized Ranking for One-Class Collaborative Filtering
+ * Based on the LibRec team's implementation
+ * </pre>
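+ *
+ * A configuration sketch (keys as read by prepare() below; the "gpbr" spelling
+ * follows the source, and the values shown are the code defaults):
+ * <pre>
+ * recommender.gpbr.rho=1.5
+ * recommender.gpbr.gsize=2
+ * </pre>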
+ *
+ * @author Birdy
+ *
+ */
+public class GBPRModel extends MatrixFactorizationModel {
+
+    private float rho;
+
+    private int gLen;
+
+    /**
+     * bias regularization
+     */
+    private float regBias;
+
+    /**
+     * item biases vector
+     */
+    private DenseVector itemBiases;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        itemBiases = DenseVector.valueOf(itemSize);
+        itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(RandomUtility.randomFloat(1F));
+        });
+
+        rho = configuration.getFloat("recommender.gpbr.rho", 1.5f);
+        gLen = configuration.getInteger("recommender.gpbr.gsize", 2);
+    }
+
+    @Override
+    protected void doPractice() {
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            // TODO consider refactoring
+            DenseMatrix userDeltas = DenseMatrix.valueOf(userSize, factorSize);
+            DenseMatrix itemDeltas = DenseMatrix.valueOf(itemSize, factorSize);
+
+            for (int sampleIndex = 0, sampleTimes = userSize * 100; sampleIndex < sampleTimes; sampleIndex++) {
+                int userIndex, positiveItemIndex, negativeItemIndex;
+                SparseVector userVector;
+                do {
+                    userIndex = RandomUtility.randomInteger(userSize);
+                    userVector = scoreMatrix.getRowVector(userIndex);
+                } while (userVector.getElementSize() == 0);
+                positiveItemIndex = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize()));
+
+                // user group set
+                Set<Integer> memberSet = new HashSet<>();
+                SparseVector positiveItemVector = scoreMatrix.getColumnVector(positiveItemIndex);
+                if (positiveItemVector.getElementSize() <= gLen) {
+                    for (VectorScalar entry : positiveItemVector) {
+                        memberSet.add(entry.getIndex());
+                    }
+                } else {
+                    memberSet.add(userIndex); // u in G
+                    while (memberSet.size() < gLen) {
+                        memberSet.add(positiveItemVector.getIndex(RandomUtility.randomInteger(positiveItemVector.getElementSize())));
+                    }
+                }
+                float positiveScore = predict(userIndex, positiveItemIndex, memberSet);
+                negativeItemIndex = RandomUtility.randomInteger(itemSize - userVector.getElementSize());
+                for (VectorScalar term : userVector) {
+                    if (negativeItemIndex >= term.getIndex()) {
+                        negativeItemIndex++;
+                    } else {
+                        break;
+                    }
+                }
+                float negativeScore = predict(userIndex, negativeItemIndex);
+                float error = positiveScore - negativeScore;
+                float value = (float) -Math.log(LogisticUtility.getValue(error));
+                totalError += value;
+                value = LogisticUtility.getValue(-error);
+
+                // update bi, bj
+                float positiveBias = itemBiases.getValue(positiveItemIndex);
+                itemBiases.shiftValue(positiveItemIndex, learnRatio * (value - regBias * positiveBias));
+                float negativeBias = itemBiases.getValue(negativeItemIndex);
+                itemBiases.shiftValue(negativeItemIndex, learnRatio * (-value - regBias * negativeBias));
+
+                // update Pw
+                float averageWeight = 1F / memberSet.size();
+                float memberSums[] = new float[factorSize];
+                for (int memberIndex : memberSet) {
+                    float delta = memberIndex == userIndex ? 1F : 0F;
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float memberFactor = userFactors.getValue(memberIndex, factorIndex);
+                        float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
+                        float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
+                        float deltaGroup = rho * averageWeight * positiveFactor + (1 - rho) * delta * positiveFactor - delta * negativeFactor;
+                        userDeltas.shiftValue(memberIndex, factorIndex, learnRatio * (value * deltaGroup - userRegularization * memberFactor));
+                        memberSums[factorIndex] += memberFactor;
+                    }
+                }
+
+                // update itemFactors
+                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                    float userFactor = userFactors.getValue(userIndex, factorIndex);
+                    float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
+                    float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
+                    float positiveDelta = rho * averageWeight * memberSums[factorIndex] + (1 - rho) * userFactor;
+                    itemDeltas.shiftValue(positiveItemIndex, factorIndex, learnRatio * (value * positiveDelta - itemRegularization * positiveFactor));
+                    float negativeDelta = -userFactor;
+                    itemDeltas.shiftValue(negativeItemIndex, factorIndex, learnRatio * (value * negativeDelta - itemRegularization * negativeFactor));
+                }
+            }
+            userFactors.addMatrix(userDeltas, false);
+            itemFactors.addMatrix(itemDeltas, false);
+
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
+
+    private float predict(int userIndex, int itemIndex, Set<Integer> memberIndexes) {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        DenseVector userVector = userFactors.getRowVector(userIndex);
+        DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+        float value = itemBiases.getValue(itemIndex) + scalar.dotProduct(userVector, itemVector).getValue();
+        float sum = 0F;
+        for (int memberIndex : memberIndexes) {
+            userVector = userFactors.getRowVector(memberIndex);
+            sum += scalar.dotProduct(userVector, itemVector).getValue();
+        }
+        float groupScore = sum / memberIndexes.size() + itemBiases.getValue(itemIndex);
+        return rho * groupScore + (1 - rho) * value;
+    }
+
+    @Override
+    protected float predict(int userIndex, int itemIndex) {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        DenseVector userVector = userFactors.getRowVector(userIndex);
+        DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+        return itemBiases.getValue(itemIndex) + scalar.dotProduct(userVector, itemVector).getValue();
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        instance.setQuantityMark(predict(userIndex, itemIndex));
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/HMMModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/HMMModel.java
new file mode 100644
index 0000000..bed1eed
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/HMMModel.java
@@ -0,0 +1,1008 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Map.Entry;
+import java.util.concurrent.CountDownLatch;
+
+import com.google.common.collect.HashBasedTable;
+import com.google.common.collect.Table;
+import com.google.common.util.concurrent.AtomicDouble;
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.attribute.MemoryQualityAttribute;
+import com.jstarcraft.ai.data.processor.DataSorter;
+import com.jstarcraft.ai.data.processor.DataSplitter;
+import com.jstarcraft.ai.environment.EnvironmentContext;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.matrix.SparseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.data.processor.AllFeatureDataSorter;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+import com.jstarcraft.rns.model.exception.ModelException;
+import com.jstarcraft.rns.utility.GammaUtility;
+
+/**
+ *
+ * HMMForCF recommender
+ *
+ * <pre>
+ * A Hidden Markov Model Purpose: A class for the model, including parameters
+ * Based on the LibRec team's implementation
+ * </pre>
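+ *
+ * A configuration sketch (keys as read by prepare() below; the context field name
+ * and state count are illustrative, the regularization values are the code
+ * defaults and must stay above the state and item counts):
+ * <pre>
+ * data.model.fields.context=context
+ * recommender.hmm.state.number=10
+ * recommender.probability.regularization=100
+ * recommender.state.regularization=100
+ * recommender.view.regularization=100
+ * </pre>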
+ *
+ * @author Birdy
+ *
+ */
+public class HMMModel extends ProbabilisticGraphicalModel {
+
+    private static float nearZero = (float) Math.pow(10F, -10F);
+
+    /** context field */
+    private String contextField;
+
+    /** context dimension */
+    private int contextDimension;
+
+    /** number of states */
+    private int numberOfStates;
+
+    /**
+     * <pre>
+     * regularization parameters
+     * probabilityRegularization and stateRegularization must be greater than numberOfStates
+     * viewRegularization must be greater than numberOfItems
+     * otherwise the viewProbabilities entries of Formula 1.9 may become NaN
+     * </pre>
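+     *
+     * (rationale, from the mStep smoothing of Formulas 1.7 to 1.9: every numerator
+     * gains regularization / count - 1 and every denominator gains
+     * regularization - count, so a regularization below the corresponding count can
+     * push the smoothed probability negative and its logarithm to NaN)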
+     */
+    private float probabilityRegularization, stateRegularization, viewRegularization;
+
+    private DenseVector probabilityNumerator;
+    private AtomicDouble probabilityDenominator = new AtomicDouble();
+    private DenseMatrix stateNumerator;
+    private DenseVector stateDenominator;
+    private DenseMatrix viewNumerator;
+    private DenseVector viewDenominator;
+
+    /** probability vector (Pi) {numberOfStates} */
+    private DenseVector probabilities;
+
+    /** state transition matrix (A) {numberOfStates, numberOfStates} */
+    private DenseMatrix stateProbabilities;
+
+    /** observation (emission) matrix (B) {numberOfStates, numberOfItems} */
+    private DenseMatrix viewProbabilities;
+
+    // numeratorPsiGamma => {numberOfStates}
+    private DenseVector numeratorPsiGamma;
+    // numeratorLogGamma => {numberOfStates}
+    private DenseVector numeratorLogGamma;
+    // denominatorPsiGamma => {numberOfStates}
+    private DenseVector denominatorPsiGamma;
+    // denominatorLogGamma => {numberOfStates}
+    private DenseVector denominatorLogGamma;
+    // psiNumerator => {numberOfStates}
+    private DenseVector psiNumerator;
+    // nutNumerator => {numberOfStates}
+    private DenseVector nutNumerator;
+    // next_denominator => {numberOfStates}
+    private DenseVector averageDenominator;
+    // numerator => {numberOfStates}
+    private DenseVector modelNumerator;
+    // denominator => {numberOfStates}
+    private DenseVector modelDenominator;
+
+    /** negative binomial distribution parameters */
+    private DenseVector alpha, beta;
+
+    /**
+     * TODO seemingly associated with users
+     *
+     * <pre>
+     * P(Z(u,t)|I(u,1:T)) |numberOfUsers|*(|sizeOfContexts|*|numberOfStates|)
+     * </pre>
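+     *
+     * (gamma holds the forward-backward posteriors: Formula 1.5 in
+     * calculateGammaRho fills each entry with exp(logAlpha + logBeta) once both
+     * passes have been normalized per context)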
+     */
+    private DenseMatrix[] gammas;
+
+    /**
+     * TODO seemingly associated with contexts
+     *
+     * <pre>
+     * P(Z(u,t-1),Z(u,t)|I(u,1:T)) |numberOfUsers|*(|numberOfStates|*|numberOfStates|)
+     * </pre>
+     */
+
+    /**
+     * TODO related to the totals
+     *
+     * <pre>
+     * |numberOfUsers|*(|sizeOfContexts|)
+     * </pre>
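+     *
+     * (each nuts[user] entry is the row sum of that user's data matrix, i.e. the
+     * total number of item observations in one context step; prepareModel fills it)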
+     */
+    private DenseVector[] nuts;
+
+    private DenseMatrix norms;
+
+    /** collection of data matrices {sizeOfContexts, numberOfItems} */
+    private SparseMatrix[] dataMatrixes;
+
+    /** context size (cache related) */
+    private int contextSize;
+
+    /**
+     * check the model
+     *
+     * @param vector
+     * @return
+     */
+    private boolean checkVector(DenseVector vector) {
+        for (VectorScalar term : vector) {
+            if (Float.isNaN(term.getValue())) {
+                return false;
+            }
+            if (Float.isInfinite(term.getValue())) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * check the model
+     *
+     * @param matrix
+     * @return
+     */
+    private boolean checkMatrix(DenseMatrix matrix) {
+        for (MatrixScalar term : matrix) {
+            if (Float.isNaN(term.getValue())) {
+                return false;
+            }
+            if (Float.isInfinite(term.getValue())) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * check the model
+     *
+     * @param model
+     * @return
+     */
+    private boolean checkModel(DenseVector model) {
+        for (VectorScalar term : model) {
+            if (Float.isNaN(term.getValue())) {
+                return false;
+            }
+            if (Float.isInfinite(term.getValue())) {
+                return false;
+            }
+            // psiGamma becomes NaN or infinite for non-positive integers.
+            // logGamma becomes NaN or infinite for negative numbers.
+            if (term.getValue() <= 0F) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    // thread-local caches
+    // E Step
+    private ThreadLocal<float[]> alphaStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> betaStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> rhoStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> binomialStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> multinomialStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> normContextStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> normStateStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> sumGammaStorage = new ThreadLocal<>();
+    private ThreadLocal<float[]> gammaSumStorage = new ThreadLocal<>();
+
+    // M Step
+    private ThreadLocal<DenseVector> numeratorStorage = new ThreadLocal<>();
+    private ThreadLocal<DenseVector> denominatorStorage = new ThreadLocal<>();
+    private ThreadLocal<DenseVector> probabilityNumeratorStorage = new ThreadLocal<>();
+    private ThreadLocal<AtomicDouble> probabilityDenominatorStorage = new ThreadLocal<>();
+    private ThreadLocal<DenseMatrix> stateNumeratorStorage = new ThreadLocal<>();
+    private ThreadLocal<DenseMatrix> viewNumeratorStorage = new ThreadLocal<>();
+    private ThreadLocal<DenseVector> viewDenominatorStorage = new ThreadLocal<>();
+
+    @Override
+    protected void constructEnvironment() {
+        // concurrent part of the E Step
+        alphaStorage.set(new float[contextSize * numberOfStates]);
+        betaStorage.set(new float[contextSize * numberOfStates]);
+        rhoStorage.set(new float[numberOfStates * numberOfStates]);
+        normContextStorage.set(new float[contextSize]);
+        normStateStorage.set(new float[numberOfStates]);
+        sumGammaStorage.set(new float[contextSize]);
+        gammaSumStorage.set(new float[contextSize]);
+        binomialStorage.set(new float[contextSize * numberOfStates]);
+        multinomialStorage.set(new float[contextSize * numberOfStates]);
+
+        // concurrent part of the M Step
+        numeratorStorage.set(DenseVector.valueOf(numberOfStates));
+        denominatorStorage.set(DenseVector.valueOf(numberOfStates));
+
+        probabilityNumeratorStorage.set(DenseVector.valueOf(numberOfStates));
+        probabilityDenominatorStorage.set(new AtomicDouble());
+        stateNumeratorStorage.set(DenseMatrix.valueOf(numberOfStates, numberOfStates));
+        viewNumeratorStorage.set(DenseMatrix.valueOf(numberOfStates, itemSize));
+        viewDenominatorStorage.set(DenseVector.valueOf(numberOfStates));
+    }
+
+    @Override
+    protected void destructEnvironment() {
+        // concurrent part of the E Step
+        alphaStorage.remove();
+        betaStorage.remove();
+        rhoStorage.remove();
+        normContextStorage.remove();
+        normStateStorage.remove();
+        sumGammaStorage.remove();
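+        // Commentary: the remove() calls here and below discard the per-thread
+        // scratch buffers installed by constructEnvironment(), so pooled worker
+        // threads do not carry ThreadLocal state across runs.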
+        gammaSumStorage.remove();
+        binomialStorage.remove();
+        multinomialStorage.remove();
+
+        // concurrent part of the M Step
+        numeratorStorage.remove();
+        denominatorStorage.remove();
+
+        probabilityNumeratorStorage.remove();
+        probabilityDenominatorStorage.remove();
+        stateNumeratorStorage.remove();
+        viewNumeratorStorage.remove();
+        viewDenominatorStorage.remove();
+    }
+
+    /**
+     * prepare the model
+     */
+    private void prepareModel() {
+        probabilityNumerator = DenseVector.valueOf(numberOfStates);
+        probabilityDenominator = new AtomicDouble();
+        stateNumerator = DenseMatrix.valueOf(numberOfStates, numberOfStates);
+        stateDenominator = DenseVector.valueOf(numberOfStates);
+        viewNumerator = DenseMatrix.valueOf(numberOfStates, itemSize);
+        viewDenominator = DenseVector.valueOf(numberOfStates);
+
+        // probabilities => {numberOfStates}
+        probabilities = DenseVector.valueOf(numberOfStates);
+        // normalize
+        probabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(RandomUtility.randomFloat(1F));
+        });
+        probabilities.scaleValues(1F / probabilities.getSum(false));
+        probabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue((float) Math.log(scalar.getValue()));
+        });
+
+        // stateProbabilities = {numberOfStates, numberOfStates}
+        stateProbabilities = DenseMatrix.valueOf(numberOfStates, numberOfStates);
+        // normalize
+        for (int stateIndex = 0; stateIndex < numberOfStates; stateIndex++) {
+            MathVector probabilities = stateProbabilities.getRowVector(stateIndex);
+            probabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                scalar.setValue(RandomUtility.randomFloat(1F));
+            });
+            probabilities.scaleValues(1F / probabilities.getSum(false));
+        }
+        stateProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue((float) Math.log(scalar.getValue()));
+        });
+
+        // viewNumerator => {numberOfItems}
+        DenseVector viewNumerator = DenseVector.valueOf(itemSize);
+        // viewDenominator => sum(sizeOfContexts)
+        float viewDenominator = 0F;
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            SparseMatrix dataMatrix = dataMatrixes[userIndex];
+            for (MatrixScalar term : dataMatrix) {
+                viewNumerator.shiftValue(term.getColumn(), term.getValue());
+            }
+            viewDenominator += dataMatrix.getRowSize();
+        }
+        // viewProbabilities => {numberOfStates, numberOfItems}
+        viewNumerator.scaleValues(1F / viewDenominator);
+        // keep the entries non-zero, then normalize
+        viewNumerator.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            float value = scalar.getValue();
+            scalar.setValue(value == 0F ? nearZero : value);
+        });
+        viewNumerator.scaleValues(1F / viewNumerator.getSum(false));
+        viewProbabilities = DenseMatrix.valueOf(numberOfStates, itemSize);
+        for (int stateIndex = 0; stateIndex < numberOfStates; stateIndex++) {
+            viewProbabilities.getRowVector(stateIndex).copyVector(viewNumerator);
+        }
+
+        // numeratorPsiGamma => {numberOfStates}
+        numeratorPsiGamma = DenseVector.valueOf(numberOfStates);
+        // numeratorLogGamma => {numberOfStates}
+        numeratorLogGamma = DenseVector.valueOf(numberOfStates);
+        // denominatorPsiGamma => {numberOfStates}
+        denominatorPsiGamma = DenseVector.valueOf(numberOfStates);
+        // denominatorLogGamma => {numberOfStates}
+        denominatorLogGamma = DenseVector.valueOf(numberOfStates);
+        // psiNumerator => {numberOfStates}
+        psiNumerator = DenseVector.valueOf(numberOfStates);
+        // nutNumerator => {numberOfStates}
+        nutNumerator = DenseVector.valueOf(numberOfStates);
+        // next_denominator => {numberOfStates}
+        averageDenominator = DenseVector.valueOf(numberOfStates);
+        // numerator => {numberOfStates}
+        modelNumerator = DenseVector.valueOf(numberOfStates);
+        // denominator => {numberOfStates}
+        modelDenominator = DenseVector.valueOf(numberOfStates);
+
+        alpha = DenseVector.valueOf(numberOfStates);
+        beta = DenseVector.valueOf(numberOfStates);
+
+        gammas = new DenseMatrix[userSize];
+        nuts = new DenseVector[userSize];
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            SparseMatrix dataMatrix = dataMatrixes[userIndex];
+            int sizeOfContexts = dataMatrix.getRowSize();
+            DenseMatrix gamma = DenseMatrix.valueOf(sizeOfContexts, numberOfStates);
+            gammas[userIndex] = gamma.setValues(1F);
+
+            DenseVector nut = DenseVector.valueOf(sizeOfContexts);
+            nut.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                SparseVector dataVector = dataMatrix.getRowVector(scalar.getIndex());
+                scalar.setValue(dataVector.getSum(false));
+            });
+            nuts[userIndex] = nut;
+        }
+
+    }
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        // context dimension
+        contextField = configuration.getString("data.model.fields.context");
+        contextDimension = model.getQualityInner(contextField);
+        numberOfStates = configuration.getInteger("recommender.hmm.state.number");
+        probabilityRegularization = configuration.getFloat("recommender.probability.regularization", 100F);
+        stateRegularization = configuration.getFloat("recommender.state.regularization", 100F);
+        viewRegularization = configuration.getFloat("recommender.view.regularization", 100F);
+
+        // check the parameter configuration
+        if (probabilityRegularization < numberOfStates || stateRegularization < numberOfStates || viewRegularization < itemSize) {
+            throw new IllegalArgumentException();
+        }
+
+        // partition the data by context
+        dataMatrixes = new SparseMatrix[userSize];
+        contextSize = 0;
+
+        MemoryQualityAttribute attribute = (MemoryQualityAttribute) space.getQualityAttribute(contextField);
+        Object[] levels = attribute.getDatas();
+        Table<Integer, Integer, Float> table = HashBasedTable.create();
+        Table<Integer, Integer, Float> data = HashBasedTable.create();
+
+        DataSplitter splitter = new QualityFeatureDataSplitter(userDimension);
+        DataModule[] models = splitter.split(model, userSize);
+        DataSorter sorter = new AllFeatureDataSorter();
+        for (int index = 0; index < userSize; index++) {
+            models[index] = sorter.sort(models[index]);
+        }
+
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            DataModule module = models[userIndex];
+            for (DataInstance instance : module) {
+                int rowKey = (Integer) levels[instance.getQualityFeature(contextDimension)];
+                int columnKey = instance.getQualityFeature(itemDimension);
+                Float count = table.get(rowKey, columnKey);
+                table.put(rowKey, columnKey, count == null ? 1 : ++count);
+            }
+
+            ArrayList<Integer> keys = new ArrayList<>(table.rowKeySet());
+            Collections.sort(keys);
+            int index = 0;
+            for (Integer key : keys) {
+                for (Entry<Integer, Float> term : table.row(key).entrySet()) {
+                    data.put(index, term.getKey(), term.getValue());
+                }
+                index++;
+            }
+            table.clear();
+
+            // use a sparse matrix
+            SparseMatrix matrix = SparseMatrix.valueOf(keys.size(), itemSize, data);
+            if (contextSize < matrix.getRowSize()) {
+                contextSize = matrix.getRowSize();
+            }
+            dataMatrixes[userIndex] = matrix;
+
+            data.clear();
+            System.out.println(userIndex + " " + matrix.getRowSize() + " " + matrix.getColumnSize());
+        }
+
+        // prepare the model
+        prepareModel();
+    }
+
+    /**
+     * norm
+     *
+     * @param vector
+     * @return
+     */
+    private float calculateNorm(DenseVector vector) {
+        // log(sum(exp(vector))), computed with the usual max-shift for numerical stability
+        float maximum = Float.NEGATIVE_INFINITY;
+        for (VectorScalar term : vector) {
+            if (maximum < term.getValue()) {
+                maximum = term.getValue();
+            }
+        }
+        float sum = 0F;
+        for (VectorScalar term : vector) {
+            sum += Math.exp(term.getValue() - maximum);
+        }
+        return (float) (maximum + Math.log(sum));
+    }
+
+    /**
+     * calculate the emission probability matrix
+     *
+     * @param matrix
+     * @return
+     */
+    private DenseMatrix calculateEmissionProbabilities(int userIndex, SparseMatrix matrix) {
+        // matrix => {sizeOfContexts, numberOfItems}
+
+        // first-fourth => {sizeOfContexts}
+        int sizeOfContexts = matrix.getRowSize();
+        DenseVector sumVector = nuts[userIndex];
+        DenseVector sumGammaVector = DenseVector.valueOf(sizeOfContexts, sumGammaStorage.get());
+        DenseVector gammaSumVector = DenseVector.valueOf(sizeOfContexts, gammaSumStorage.get());
+
+        for (int contextIndex = 0; contextIndex < sizeOfContexts; contextIndex++) {
+            sumGammaVector.setValue(contextIndex, GammaUtility.logGamma(sumVector.getValue(contextIndex) + 1F));
+
+            SparseVector vector = matrix.getRowVector(contextIndex);
+            float value = 0F;
+            for (VectorScalar term : vector) {
+                value += GammaUtility.logGamma(term.getValue() + 1F);
+            }
+            value += (GammaUtility.logGamma(1F) * (itemSize - vector.getElementSize()));
+            gammaSumVector.setValue(contextIndex, value);
+        }
+
+        DenseVector logAlpha = DenseVector.valueOf(numberOfStates, alphaStorage.get());
+        logAlpha.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(GammaUtility.logGamma(alpha.getValue(scalar.getIndex())));
+        });
+        DenseVector logBeta = DenseVector.valueOf(numberOfStates, betaStorage.get());
+        logBeta.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue((float) Math.log(beta.getValue(scalar.getIndex()) + 1F));
+        });
+
+        // calculate the negative binomial matrix
+        // binomial => {sizeOfContexts, numberOfStates}
+        DenseMatrix binomial = DenseMatrix.valueOf(matrix.getRowSize(), numberOfStates, binomialStorage.get());
+        binomial.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            int row = scalar.getRow();
+            int column = scalar.getColumn();
+            float value = GammaUtility.logGamma(sumVector.getValue(row) + alpha.getValue(column));
+            value -= sumGammaVector.getValue(row);
+            value += (sumVector.getValue(row) * Math.log(beta.getValue(column)));
+            value -= ((sumVector.getValue(row) + alpha.getValue(column)) * logBeta.getValue(column));
+            value -= logAlpha.getValue(column);
+            scalar.setValue(value);
+        });
+
+        sumGammaVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            int index = scalar.getIndex();
+            float value =
scalar.getValue(); + scalar.setValue(value - gammaSumVector.getValue(index)); + }); + + // 计算多项分布矩阵 + // multinomial => {sizeOfContexts, numberOfStates} + DenseMatrix multinomial = DenseMatrix.valueOf(matrix.getRowSize(), numberOfStates, multinomialStorage.get()); + multinomial.dotProduct(matrix, false, viewProbabilities, true, MathCalculator.SERIAL); + for (int index = 0; index < multinomial.getRowSize(); index++) { + multinomial.getRowVector(index).shiftValues(sumGammaVector.getValue(index)); + } + + // emission => {sizeOfContexts, numberOfStates} + binomial.addMatrix(multinomial, false); + DenseMatrix emission = binomial; + return emission; + } + + /** + * 计算 gamma and rho. + * + * @param matrix + * @return + */ + private void calculateGammaRho(int userIndex, SparseMatrix matrix) { + // Formula 1.1 + DenseMatrix logEmission = calculateEmissionProbabilities(userIndex, matrix); + + // calculateAlphaBeta + int sizeOfContexts = logEmission.getRowSize(); + // TODO 此处按照建议,可以考虑改为一个极大的负数作为logAlpha和logBeta的初始化值.具体效果待验证. + // logAlpha => {sizeOfContexts, numberOfStates} + DenseMatrix logAlpha = DenseMatrix.valueOf(sizeOfContexts, numberOfStates, alphaStorage.get()); + logAlpha.setValues(0F); + // logBeta => {sizeOfContexts, numberOfStates} + DenseMatrix logBeta = DenseMatrix.valueOf(sizeOfContexts, numberOfStates, betaStorage.get()); + logBeta.setValues(0F); + // contextNorm => {sizeOfContexts} + DenseVector contextNorm = DenseVector.valueOf(sizeOfContexts, normContextStorage.get()); + // stateNorm => {numberOfStates} + DenseVector stateNorm = DenseVector.valueOf(numberOfStates, normStateStorage.get()); + + // Formula 1.2 + DenseVector emission = logEmission.getRowVector(0); + logAlpha.getRowVector(0).iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + scalar.setValue(probabilities.getValue(index) + emission.getValue(index)); + }); + float norm = calculateNorm(logAlpha.getRowVector(0)); + contextNorm.setValue(0, norm); + logAlpha.getRowVector(0).shiftValues(-norm); + + // Formula 1.3 + for (int context = 1; context < sizeOfContexts; context++) { + int contextIndex = context; + for (int state = 0; state < numberOfStates; state++) { + stateNorm.copyVector(stateProbabilities.getColumnVector(state)); + stateNorm.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value + logAlpha.getValue(contextIndex - 1, index)); + }); + norm = calculateNorm(stateNorm) + logEmission.getValue(context, state); + logAlpha.setValue(context, state, norm); + } + norm = calculateNorm(logAlpha.getRowVector(context)); + contextNorm.setValue(context, norm); + logAlpha.getRowVector(context).shiftValues(-norm); + } + + // Formula 1.4 + for (int context = sizeOfContexts - 2; context > -1; context--) { + int contextIndex = context; + for (int state = 0; state < numberOfStates; state++) { + stateNorm.copyVector(stateProbabilities.getRowVector(state)); + stateNorm.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value + logBeta.getValue(contextIndex + 1, index) + logEmission.getValue(contextIndex + 1, index)); + }); + norm = calculateNorm(stateNorm); + logBeta.setValue(context, state, norm); + } + logBeta.getRowVector(context).shiftValues(-contextNorm.getValue(context + 1)); + } + + // Formula 1.5 + gammas[userIndex].iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = 
scalar.getColumn(); + scalar.setValue((float) (Math.exp(logAlpha.getValue(row, column) + logBeta.getValue(row, column)))); + }); + + // Formula 1.6 + logEmission.addMatrix(logBeta, false); + DenseMatrix rho = DenseMatrix.valueOf(numberOfStates, numberOfStates, rhoStorage.get()); + DenseMatrix stateNumeratorCache = stateNumeratorStorage.get(); + for (int context = 0; context < sizeOfContexts - 1; context++) { + int contextIndex = context; + for (int state = 0; state < numberOfStates; state++) { + int stateIndex = state; + rho.getRowVector(state).iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + scalar.setValue(stateProbabilities.getValue(stateIndex, index) + logEmission.getValue(contextIndex + 1, index) + logAlpha.getValue(contextIndex, stateIndex)); + }); + } + float normValue = contextNorm.getValue(context + 1); + rho.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((float) Math.exp(scalar.getValue() - normValue)); + }); + stateNumeratorCache.addMatrix(rho, false); + } + } + + @Override + protected boolean isConverged(int iterationStep) { + // calculate the expected likelihood. + // Formula 1.12 + float likelihood = 0F; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + // gamma => {sizeOfContexts, numberOfStates} + DenseMatrix gamma = gammas[userIndex]; + if (gamma.getRowSize() == 0) { + // 处理用户sizeOfContexts == 0的情况 + continue; + } + float probability = 0F; + for (int stateIndex = 0; stateIndex < numberOfStates; stateIndex++) { + probability += (gamma.getValue(0, stateIndex) * probabilities.getValue(stateIndex)); + } + + // // rho => {sizeOfContexts - 1, numberOfStates, numberOfStates} + // DenseMatrix rho = logRhos[userIndex]; + // double state = rho.calculate((row, column, value) -> { + // return value * stateProbabilities.getTermValue(row, column); + // }).sum(); + + // // binomial => {sizeOfContexts, numberOfStates} + // DenseMatrix binomial = + // calculateNegativeBinomial(dataMatrixes[userIndex]); + // double negative = binomial.calculate((row, column, value) -> { + // return value * gamma.getTermValue(row, column); + // }).sum(); + // + // // multinomial => {sizeOfContexts, numberOfStates} + // DenseMatrix multinomial = + // calculateMultinomial(dataMatrixes[userIndex]); + // double positive = multinomial.calculate((row, column, value) -> { + // return value * gamma.getTermValue(row, column); + // }).sum(); + + likelihood += probability; + if (Float.isNaN(likelihood)) { + throw new IllegalStateException(); + } + } + + // 是否收敛 + System.out.println(iterationStep + " " + likelihood); + float deltaLoss = likelihood - currentError; + if (iterationStep > 1 && (deltaLoss < 0.1F)) { + return true; + } + currentError = likelihood; + return false; + } + + /** + * 计算隐马尔可夫模型 + * + * @param dataMatrixes + * @return + */ + private void calculateModel() { + EnvironmentContext context = EnvironmentContext.getContext(); + psiNumerator.setValues(0F); + nutNumerator.setValues(0F); + averageDenominator.setValues(0F); + + // TODO 此处可以并发 + for (int userIndex = 0; userIndex < userSize; userIndex++) { + DenseMatrix gamma = gammas[userIndex]; + DenseVector nut = nuts[userIndex]; + + nutNumerator.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + for (VectorScalar term : gamma.getColumnVector(index)) { + value += (term.getValue() * nut.getValue(term.getIndex())); + } + scalar.setValue(value); + }); + + psiNumerator.iterateElement(MathCalculator.SERIAL, (scalar) -> 
{ + int index = scalar.getIndex(); + float value = scalar.getValue(); + for (VectorScalar term : gamma.getColumnVector(index)) { + // TODO 减少PolyGamma.psigamma计算 + value += (term.getValue() * GammaUtility.digamma(nut.getValue(term.getIndex()))); + } + scalar.setValue(value); + }); + + averageDenominator.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value + gamma.getColumnVector(index).getSum(false)); + }); + } + + beta.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(nutNumerator.getValue(index) / averageDenominator.getValue(index)); + }); + alpha.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue((float) (0.5F / (Math.log(beta.getValue(index)) - psiNumerator.getValue(index) / averageDenominator.getValue(index)))); + }); + + for (int sampleIndex = 0; sampleIndex < sampleSize; sampleIndex++) { + modelNumerator.setValues(0F); + modelDenominator.setValues(0F); + { + context.doAlgorithmByEvery(() -> { + numeratorStorage.get().setValues(0F); + denominatorStorage.get().setValues(0F); + }); + } + + for (int stateIndex = 0; stateIndex < numberOfStates; stateIndex++) { + numeratorPsiGamma.setValue(stateIndex, GammaUtility.digamma(alpha.getValue(stateIndex))); + numeratorLogGamma.setValue(stateIndex, (float) (Math.log(beta.getValue(stateIndex) / alpha.getValue(stateIndex) + 1F))); + denominatorPsiGamma.setValue(stateIndex, GammaUtility.trigamma(alpha.getValue(stateIndex))); + denominatorLogGamma.setValue(stateIndex, (float) (1F / (beta.getValue(stateIndex) + alpha.getValue(stateIndex)))); + } + + { + // 并发计算 + CountDownLatch latch = new CountDownLatch(userSize); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + DenseMatrix gamma = gammas[userIndex]; + DenseVector nut = nuts[userIndex]; + context.doAlgorithmByAny(userIndex, () -> { + // numeratorMatrix => {sizeOfContexts, numberOfStates} + DenseVector numeratorCache = numeratorStorage.get(); + // denominatorMatrix => {sizeOfContexts, numberOfStates} + DenseVector denominatorCache = denominatorStorage.get(); + for (MatrixScalar term : gamma) { + int row = term.getRow(); + int column = term.getColumn(); + numeratorCache.shiftValue(column, term.getValue() * (GammaUtility.digamma(nut.getValue(row) + alpha.getValue(column)) - numeratorPsiGamma.getValue(column) - numeratorLogGamma.getValue(column))); + denominatorCache.shiftValue(column, term.getValue() * (GammaUtility.trigamma(nut.getValue(row) + alpha.getValue(column)) - denominatorPsiGamma.getValue(column) - denominatorLogGamma.getValue(column) + (1F / alpha.getValue(column)))); + } + latch.countDown(); + }); + } + try { + latch.await(); + } catch (Exception exception) { + throw new ModelException(exception); + } + } + + { + context.doAlgorithmByEvery(() -> { + synchronized (modelNumerator) { + modelNumerator.addVector(numeratorStorage.get()); + } + synchronized (modelDenominator) { + modelDenominator.addVector(denominatorStorage.get()); + } + }); + } + + // TODO 此处相当于学习率 + modelDenominator.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + value = (value == 0D ? 
nearZero : value); + scalar.setValue(modelNumerator.getValue(index) / value); + }); + boolean isBreak = false; + for (VectorScalar term : alpha) { + if (term.getValue() <= modelDenominator.getValue(term.getIndex())) { + isBreak = true; + break; + } + } + if (isBreak) { + break; + } + alpha.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value - modelDenominator.getValue(index)); + }); + if (!checkModel(alpha)) { + throw new IllegalStateException(); + } + } + + beta.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value / alpha.getValue(index)); + }); + } + + @Override + protected void doPractice() { + calculateModel(); + + super.doPractice(); + + // Formula 1.13 and Formula 1.14 + norms = DenseMatrix.valueOf(userSize, numberOfStates); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + gammas[userIndex].iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((float) Math.log(scalar.getValue())); + }); + for (int stateIndex = 0; stateIndex < numberOfStates; stateIndex++) { + int state = stateIndex; + // gamma => {numberOfStates} + DenseMatrix gamma = gammas[userIndex]; + // 处理用户sizeOfContexts == 0的情况 + DenseVector norm = DenseVector.copyOf(gamma.getRowSize() == 0 ? probabilities : gamma.getRowVector(gamma.getRowSize() - 1)); + norm.addVector(stateProbabilities.getColumnVector(state)); + norms.setValue(userIndex, stateIndex, calculateNorm(norm)); + } + } + + viewProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue((float) (Math.log(Math.exp(value) * beta.getValue(row) + 1D) * alpha.getValue(row))); + }); + } + + @Override + protected void eStep() { + EnvironmentContext context = EnvironmentContext.getContext(); + // 并发计算 + CountDownLatch latch = new CountDownLatch(userSize); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + int user = userIndex; + context.doAlgorithmByAny(userIndex, () -> { + calculateGammaRho(user, dataMatrixes[user]); + latch.countDown(); + }); + } + try { + latch.await(); + } catch (Exception exception) { + throw new ModelException(exception); + } + } + + @Override + protected void mStep() { + EnvironmentContext context = EnvironmentContext.getContext(); + probabilityNumerator.setValues(0F); + probabilityDenominator.set(0D); + stateNumerator.setValues(0F); + viewNumerator.setValues(0F); + viewDenominator.setValues(0F); + + { + context.doAlgorithmByEvery(() -> { + probabilityNumeratorStorage.get().setValues(0F); + probabilityDenominatorStorage.get().set(0D); + stateNumerator.addMatrix(stateNumeratorStorage.get(), false); + stateNumeratorStorage.get().setValues(0F); + viewNumeratorStorage.get().setValues(0F); + viewDenominatorStorage.get().setValues(0F); + }); + } + + stateDenominator.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + scalar.setValue(stateNumerator.getRowVector(index).getSum(false)); + }); + + { + // 并发计算 + CountDownLatch latch = new CountDownLatch(userSize); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + // gamma => {sizeOfContexts, numberOfStates} + DenseMatrix gamma = gammas[userIndex]; + if (gamma.getRowSize() == 0) { + // 处理用户sizeOfContexts == 0的情况 + latch.countDown(); + continue; + } + DenseVector nut = nuts[userIndex]; + SparseMatrix dataMatrix = 
dataMatrixes[userIndex]; + context.doAlgorithmByAny(userIndex, () -> { + MathVector gammaVector = gamma.getRowVector(0); + probabilityNumeratorStorage.get().addVector(gammaVector); + probabilityDenominatorStorage.get().addAndGet(gammaVector.getSum(false)); + // viewNumerator => {numberOfStates, numberOfItems} + // 利用稀疏矩阵减少计算. + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + if (dataMatrix.getColumnScope(itemIndex) > 0) { + SparseVector itemVector = dataMatrix.getColumnVector(itemIndex); + viewNumeratorStorage.get().getColumnVector(itemIndex).iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + for (VectorScalar term : itemVector) { + value += (gamma.getValue(term.getIndex(), index) * term.getValue()); + } + scalar.setValue(value); + }); + } + } + // viewDenominator => {numberOfStates} + viewDenominatorStorage.get().iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + for (VectorScalar term : nut) { + value += (gamma.getValue(term.getIndex(), index) * term.getValue()); + } + scalar.setValue(value); + }); + latch.countDown(); + }); + } + try { + latch.await(); + } catch (Exception exception) { + throw new ModelException(exception); + } + } + + { + context.doAlgorithmByEvery(() -> { + synchronized (probabilityNumerator) { + probabilityNumerator.addVector(probabilityNumeratorStorage.get()); + } + synchronized (probabilityDenominator) { + probabilityDenominator.addAndGet(probabilityDenominatorStorage.get().get()); + } + synchronized (viewNumerator) { + viewNumerator.addMatrix(viewNumeratorStorage.get(), false); + } + synchronized (viewDenominator) { + viewDenominator.addVector(viewDenominatorStorage.get()); + } + }); + } + + // Formula 1.7 + probabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = probabilityNumerator.getValue(index) + (probabilityRegularization / numberOfStates - 1F); + value = (float) (value / (probabilityDenominator.get() + probabilityRegularization - numberOfStates)); + value = (float) Math.log(value); + scalar.setValue(value); + }); + + // Formula 1.8 + stateProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = stateNumerator.getValue(row, column) + (stateRegularization / numberOfStates - 1F); + value = (float) (value / (stateDenominator.getValue(row) + stateRegularization - numberOfStates)); + value = (float) Math.log(value); + scalar.setValue(value); + }); + + // Formula 1.9 + viewProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = viewNumerator.getValue(row, column) + (viewRegularization / itemSize - 1F); + value = (float) (value / (viewDenominator.getValue(row) + viewRegularization - itemSize)); + value = (float) Math.log(value); + scalar.setValue(value); + }); + + // Formula 1.10 and Formula 1.11 + calculateModel(); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float score = 0F; + for (int state = 0; state < numberOfStates; state++) { + score += Math.exp(norms.getValue(userIndex, state) - viewProbabilities.getValue(state, itemIndex)); + } + score = 1F - score; + instance.setQuantityMark(score); + } + +} diff --git 
a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ItemBigramModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ItemBigramModel.java
new file mode 100644
index 0000000..c83e5ef
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ItemBigramModel.java
@@ -0,0 +1,361 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.KeyValue;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.core.utility.StringUtility;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+import com.jstarcraft.rns.utility.GammaUtility;
+import com.jstarcraft.rns.utility.SampleUtility;
+
+import it.unimi.dsi.fastutil.ints.Int2IntRBTreeMap;
+
+/**
+ *
+ * Item Bigram recommender
+ *
+ * <pre>
+ * Topic modeling: beyond bag-of-words
+ * Based on the LibRec team's implementation
+ * </pre>
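+ *
+ * A configuration sketch (keys as read by prepare() below; the instant field name
+ * is illustrative, the priors are the code defaults):
+ * <pre>
+ * data.model.fields.instant=instant
+ * recommender.user.dirichlet.prior=0.01
+ * recommender.topic.dirichlet.prior=0.01
+ * </pre>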
+ *
+ * @author Birdy
+ *
+ */
+public class ItemBigramModel extends ProbabilisticGraphicalModel {
+
+    /** instant (timestamp) field */
+    private String instantField;
+
+    /** instant dimension */
+    private int instantDimension;
+
+    private Map<Integer, List<Integer>> userItemMap;
+
+    /**
+     * k: current topic; j: previously rated item; i: current item
+     */
+    private int[][][] topicItemBigramTimes;
+    private DenseMatrix topicItemProbabilities;
+    private float[][][] topicItemBigramProbabilities, topicItemBigramSums;
+
+    private DenseMatrix beta;
+
+    /**
+     * vector of hyperparameters for alpha
+     */
+    private DenseVector alpha;
+
+    /**
+     * Dirichlet hyper-parameters of user-topic distribution: typical value is 50/K
+     */
+    private float initAlpha;
+
+    /**
+     * Dirichlet hyper-parameters of topic-item distribution, typical value is 0.01
+     */
+    private float initBeta;
+
+    /**
+     * cumulative statistics of theta, phi
+     */
+    private DenseMatrix userTopicSums;
+
+    /**
+     * entry[u, k]: number of tokens assigned to topic k, given user u.
+     */
+    private DenseMatrix userTopicTimes;
+
+    /**
+     * entry[u]: number of tokens rated by user u.
+     */
+    private DenseVector userTokenNumbers;
+
+    /**
+     * posterior probabilities of parameters
+     */
+    private DenseMatrix userTopicProbabilities;
+
+    /**
+     * entry[u, i, k]: topic assignment as sparse structure
+     */
+    // TODO consider having DenseMatrix support the Integer type
+    private Int2IntRBTreeMap topicAssignments;
+
+    private DenseVector randomProbabilities;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        initAlpha = configuration.getFloat("recommender.user.dirichlet.prior", 0.01F);
+        initBeta = configuration.getFloat("recommender.topic.dirichlet.prior", 0.01F);
+
+        instantField = configuration.getString("data.model.fields.instant");
+        instantDimension = model.getQualityInner(instantField);
+        Int2IntRBTreeMap instantTabel = new Int2IntRBTreeMap();
+        instantTabel.defaultReturnValue(-1);
+        for (DataInstance sample : model) {
+            int instant = instantTabel.get(sample.getQualityFeature(userDimension) * itemSize + sample.getQualityFeature(itemDimension));
+            if (instant == -1) {
+                instant = sample.getQualityFeature(instantDimension);
+            } else {
+                instant = sample.getQualityFeature(instantDimension) > instant ? sample.getQualityFeature(instantDimension) : instant;
+            }
+            instantTabel.put(sample.getQualityFeature(userDimension) * itemSize + sample.getQualityFeature(itemDimension), instant);
+        }
+        // build the training data, sorting by date
+        userItemMap = new HashMap<>();
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            // TODO consider optimizing
+            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+            if (userVector.getElementSize() == 0) {
+                continue;
+            }
+
+            // sort by time
+            List<KeyValue<Integer, Integer>> instants = new ArrayList<>(userVector.getElementSize());
+            for (VectorScalar term : userVector) {
+                int itemIndex = term.getIndex();
+                instants.add(new KeyValue<>(itemIndex, instantTabel.get(userIndex * itemSize + itemIndex)));
+            }
+            Collections.sort(instants, (left, right) -> {
+                // ascending order
+                return left.getValue().compareTo(right.getValue());
+            });
+            List<Integer> items = new ArrayList<>(userVector.getElementSize());
+            for (KeyValue<Integer, Integer> term : instants) {
+                items.add(term.getKey());
+            }
+
+            userItemMap.put(userIndex, items);
+        }
+
+        // initialize count variables
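+        // Commentary: topicItemBigramTimes is indexed [topic][previous item][next item];
+        // the extra previous-item slot at index itemSize is the sentinel for "no
+        // previous record", used for the first item in a user's time-ordered list.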
+ userTopicTimes = DenseMatrix.valueOf(userSize, factorSize); + userTokenNumbers = DenseVector.valueOf(userSize); + + // 注意:numItems + 1的最后一个元素代表没有任何记录的概率 + topicItemBigramTimes = new int[factorSize][itemSize + 1][itemSize]; + topicItemProbabilities = DenseMatrix.valueOf(factorSize, itemSize + 1); + + // Logs.debug("topicPreItemCurItemNum consumes {} bytes", + // Strings.toString(Memory.bytes(topicPreItemCurItemNum))); + + // parameters + userTopicSums = DenseMatrix.valueOf(userSize, factorSize); + topicItemBigramSums = new float[factorSize][itemSize + 1][itemSize]; + topicItemBigramProbabilities = new float[factorSize][itemSize + 1][itemSize]; + + // hyper-parameters + alpha = DenseVector.valueOf(factorSize); + alpha.setValues(initAlpha); + + beta = DenseMatrix.valueOf(factorSize, itemSize + 1); + beta.setValues(initBeta); + + // initialization + topicAssignments = new Int2IntRBTreeMap(); + for (Entry> term : userItemMap.entrySet()) { + int userIndex = term.getKey(); + List items = term.getValue(); + + for (int index = 0; index < items.size(); index++) { + int nextItemIndex = items.get(index); + // TODO 需要重构 + int topicIndex = RandomUtility.randomInteger(factorSize); + topicAssignments.put(userIndex * itemSize + nextItemIndex, topicIndex); + + userTopicTimes.shiftValue(userIndex, topicIndex, 1F); + userTokenNumbers.shiftValue(userIndex, 1F); + + int previousItemIndex = index > 0 ? items.get(index - 1) : itemSize; + topicItemBigramTimes[topicIndex][previousItemIndex][nextItemIndex]++; + topicItemProbabilities.shiftValue(topicIndex, previousItemIndex, 1F); + } + } + + randomProbabilities = DenseVector.valueOf(factorSize); + } + + @Override + protected void eStep() { + float sumAlpha = alpha.getSum(false); + DenseVector topicVector = DenseVector.valueOf(factorSize); + topicVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(beta.getRowVector(scalar.getIndex()).getSum(false)); + }); + + for (Entry> term : userItemMap.entrySet()) { + int userIndex = term.getKey(); + List items = term.getValue(); + + for (int index = 0; index < items.size(); index++) { + int nextItemIndex = items.get(index); + int assignmentIndex = topicAssignments.get(userIndex * itemSize + nextItemIndex); + + userTopicTimes.shiftValue(userIndex, assignmentIndex, -1F); + userTokenNumbers.shiftValue(userIndex, -1F); + + int previousItemIndex = index > 0 ? 
+
+    @Override
+    protected void eStep() {
+        float sumAlpha = alpha.getSum(false);
+        DenseVector topicVector = DenseVector.valueOf(factorSize);
+        topicVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(beta.getRowVector(scalar.getIndex()).getSum(false));
+        });
+
+        for (Entry<Integer, List<Integer>> term : userItemMap.entrySet()) {
+            int userIndex = term.getKey();
+            List<Integer> items = term.getValue();
+
+            for (int index = 0; index < items.size(); index++) {
+                int nextItemIndex = items.get(index);
+                int assignmentIndex = topicAssignments.get(userIndex * itemSize + nextItemIndex);
+
+                userTopicTimes.shiftValue(userIndex, assignmentIndex, -1F);
+                userTokenNumbers.shiftValue(userIndex, -1F);
+
+                int previousItemIndex = index > 0 ? items.get(index - 1) : itemSize;
+                topicItemBigramTimes[assignmentIndex][previousItemIndex][nextItemIndex]--;
+                topicItemProbabilities.shiftValue(assignmentIndex, previousItemIndex, -1F);
+
+                // compute the sampling probabilities
+                DefaultScalar sum = DefaultScalar.getInstance();
+                sum.setValue(0F);
+                randomProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                    int topicIndex = scalar.getIndex();
+                    float userProbability = (userTopicTimes.getValue(userIndex, topicIndex) + alpha.getValue(topicIndex)) / (userTokenNumbers.getValue(userIndex) + sumAlpha);
+                    float topicProbability = (topicItemBigramTimes[topicIndex][previousItemIndex][nextItemIndex] + beta.getValue(topicIndex, previousItemIndex)) / (topicItemProbabilities.getValue(topicIndex, previousItemIndex) + topicVector.getValue(topicIndex));
+                    float value = userProbability * topicProbability;
+                    sum.shiftValue(value);
+                    scalar.setValue(sum.getValue());
+                });
+
+                int randomIndex = SampleUtility.binarySearch(randomProbabilities, 0, randomProbabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue()));
+                topicAssignments.put(userIndex * itemSize + nextItemIndex, randomIndex);
+                userTopicTimes.shiftValue(userIndex, randomIndex, 1F);
+                userTokenNumbers.shiftValue(userIndex, 1F);
+                topicItemBigramTimes[randomIndex][previousItemIndex][nextItemIndex]++;
+                topicItemProbabilities.shiftValue(randomIndex, previousItemIndex, 1F);
+            }
+        }
+    }
+
+    @Override
+    protected void mStep() {
+        float denominator = 0F;
+        float value = 0F;
+
+        float alphaSum = alpha.getSum(false);
+        float alphaDigamma = GammaUtility.digamma(alphaSum);
+        float alphaValue;
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            // TODO should be changed to a sparse vector
+            value = userTokenNumbers.getValue(userIndex);
+            if (value != 0F) {
+                denominator += GammaUtility.digamma(value + alphaSum) - alphaDigamma;
+            }
+        }
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            alphaValue = alpha.getValue(topicIndex);
+            alphaDigamma = GammaUtility.digamma(alphaValue);
+            float numerator = 0F;
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                // TODO should be changed to a sparse matrix
+                value = userTopicTimes.getValue(userIndex, topicIndex);
+                if (value != 0F) {
+                    numerator += GammaUtility.digamma(value + alphaValue) - alphaDigamma;
+                }
+            }
+            if (numerator != 0F) {
+                alpha.setValue(topicIndex, alphaValue * (numerator / denominator));
+            }
+        }
+
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            float betaSum = beta.getRowVector(topicIndex).getSum(false);
+            float betaDigamma = GammaUtility.digamma(betaSum);
+            float betaValue;
+            float[] denominators = new float[itemSize + 1];
+            for (int itemIndex = 0; itemIndex < itemSize + 1; itemIndex++) {
+                // TODO should be changed to a sparse matrix
+                value = topicItemProbabilities.getValue(topicIndex, itemIndex);
+                if (value != 0F) {
+                    denominators[itemIndex] = GammaUtility.digamma(value + betaSum) - betaDigamma;
+                }
+            }
+            for (int previousItemIndex = 0; previousItemIndex < itemSize + 1; previousItemIndex++) {
+                betaValue = beta.getValue(topicIndex, previousItemIndex);
+                betaDigamma = GammaUtility.digamma(betaValue);
+                float numerator = 0F;
+                denominator = 0F;
+                for (int nextItemIndex = 0; nextItemIndex < itemSize; nextItemIndex++) {
+                    // TODO should be changed to a sparse tensor
+                    value = topicItemBigramTimes[topicIndex][previousItemIndex][nextItemIndex];
+                    if (value != 0F) {
+                        numerator += GammaUtility.digamma(value + betaValue) - betaDigamma;
+                    }
+                    denominator += denominators[previousItemIndex];
+                }
+                if (numerator != 0F) {
+                    beta.setValue(topicIndex, previousItemIndex, betaValue * (numerator / denominator));
+                }
+            }
+        }
+    }
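+
+    /*
+     * mStep above follows Minka's fixed-point iteration for Dirichlet
+     * hyperparameters (assuming GammaUtility.digamma is the digamma function):
+     *
+     *   alpha_k <- alpha_k * sum_u [digamma(n_uk + alpha_k) - digamma(alpha_k)]
+     *                      / sum_u [digamma(n_u + sum(alpha)) - digamma(sum(alpha))]
+     *
+     * with an analogous update for every row of beta.
+     */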
+
+    @Override
+    protected void readoutParameters() {
+        float value;
+        float sumAlpha = alpha.getSum(false);
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                value = (userTopicTimes.getValue(userIndex, topicIndex) + alpha.getValue(topicIndex)) / (userTokenNumbers.getValue(userIndex) + sumAlpha);
+                userTopicSums.shiftValue(userIndex, topicIndex, value);
+            }
+        }
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            float betaTopicValue = beta.getRowVector(topicIndex).getSum(false);
+            for (int previousItemIndex = 0; previousItemIndex < itemSize + 1; previousItemIndex++) {
+                for (int nextItemIndex = 0; nextItemIndex < itemSize; nextItemIndex++) {
+                    value = (topicItemBigramTimes[topicIndex][previousItemIndex][nextItemIndex] + beta.getValue(topicIndex, previousItemIndex)) / (topicItemProbabilities.getValue(topicIndex, previousItemIndex) + betaTopicValue);
+                    topicItemBigramSums[topicIndex][previousItemIndex][nextItemIndex] += value;
+                }
+            }
+        }
+        if (logger.isInfoEnabled()) {
+            String message = StringUtility.format("sumAlpha is {}", sumAlpha);
+            logger.info(message);
+        }
+        numberOfStatistics++;
+    }
+
+    @Override
+    protected void estimateParameters() {
+        userTopicProbabilities = DenseMatrix.copyOf(userTopicSums);
+        userTopicProbabilities.scaleValues(1F / numberOfStatistics);
+
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            for (int previousItemIndex = 0; previousItemIndex < itemSize + 1; previousItemIndex++) {
+                for (int nextItemIndex = 0; nextItemIndex < itemSize; nextItemIndex++) {
+                    topicItemBigramProbabilities[topicIndex][previousItemIndex][nextItemIndex] = topicItemBigramSums[topicIndex][previousItemIndex][nextItemIndex] / numberOfStatistics;
+                }
+            }
+        }
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        List<Integer> items = userItemMap.get(userIndex);
+        // condition on the user's last rated item (or the "no record" sentinel)
+        int scoreIndex = items == null ? itemSize : items.get(items.size() - 1);
+        float value = 0F;
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            value += userTopicProbabilities.getValue(userIndex, topicIndex) * topicItemBigramProbabilities[topicIndex][scoreIndex][itemIndex];
+        }
+
+        instance.setQuantityMark(value);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ItemKNNRankingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ItemKNNRankingModel.java
new file mode 100644
index 0000000..ca3ad0b
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ItemKNNRankingModel.java
@@ -0,0 +1,76 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Iterator;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.rns.model.collaborative.ItemKNNModel;
+
+/**
+ *
+ * Item KNN recommender
+ *
+ * <pre>
+ * Reference: LibRec team
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class ItemKNNRankingModel extends ItemKNNModel {
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        SparseVector userVector = userVectors[userIndex];
+        MathVector neighbors = itemNeighbors[itemIndex];
+        if (userVector.getElementSize() == 0 || neighbors.getElementSize() == 0) {
+            instance.setQuantityMark(0F);
+            return;
+        }
+
+        float sum = 0F, absolute = 0F;
+        int count = 0;
+        int leftCursor = 0, rightCursor = 0, leftSize = userVector.getElementSize(), rightSize = neighbors.getElementSize();
+        Iterator<VectorScalar> leftIterator = userVector.iterator();
+        VectorScalar leftTerm = leftIterator.next();
+        Iterator<VectorScalar> rightIterator = neighbors.iterator();
+        VectorScalar rightTerm = rightIterator.next();
+        // merge join: find the indexes shared by the two ordered vectors
+        while (leftCursor < leftSize && rightCursor < rightSize) {
+            if (leftTerm.getIndex() == rightTerm.getIndex()) {
+                count++;
+                sum += rightTerm.getValue();
+                if (leftIterator.hasNext()) {
+                    leftTerm = leftIterator.next();
+                }
+                if (rightIterator.hasNext()) {
+                    rightTerm = rightIterator.next();
+                }
+                leftCursor++;
+                rightCursor++;
+            } else if (leftTerm.getIndex() > rightTerm.getIndex()) {
+                if (rightIterator.hasNext()) {
+                    rightTerm = rightIterator.next();
+                }
+                rightCursor++;
+            } else if (leftTerm.getIndex() < rightTerm.getIndex()) {
+                if (leftIterator.hasNext()) {
+                    leftTerm = leftIterator.next();
+                }
+                leftCursor++;
+            }
+        }
+
+        if (count == 0) {
+            instance.setQuantityMark(0F);
+            return;
+        }
+
+        instance.setQuantityMark(sum);
+    }
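+
+    /*
+     * A minimal sketch of the merge join performed in predict above: both
+     * vectors iterate in ascending index order, so one linear pass finds the
+     * shared indexes. This illustrative variant only counts the intersection.
+     */
+    private static int countCommonIndexes(MathVector leftVector, MathVector rightVector) {
+        int count = 0;
+        Iterator<VectorScalar> leftIterator = leftVector.iterator();
+        Iterator<VectorScalar> rightIterator = rightVector.iterator();
+        VectorScalar leftTerm = leftIterator.hasNext() ? leftIterator.next() : null;
+        VectorScalar rightTerm = rightIterator.hasNext() ? rightIterator.next() : null;
+        while (leftTerm != null && rightTerm != null) {
+            if (leftTerm.getIndex() == rightTerm.getIndex()) {
+                // both vectors contain this index
+                count++;
+                leftTerm = leftIterator.hasNext() ? leftIterator.next() : null;
+                rightTerm = rightIterator.hasNext() ? rightIterator.next() : null;
+            } else if (leftTerm.getIndex() < rightTerm.getIndex()) {
+                leftTerm = leftIterator.hasNext() ? leftIterator.next() : null;
+            } else {
+                rightTerm = rightIterator.hasNext() ? rightIterator.next() : null;
+            }
+        }
+        return count;
+    }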
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LDAModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LDAModel.java
new file mode 100644
index 0000000..20e6bae
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LDAModel.java
@@ -0,0 +1,294 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+import com.jstarcraft.rns.model.exception.ModelException;
+import com.jstarcraft.rns.utility.GammaUtility;
+import com.jstarcraft.rns.utility.SampleUtility;
+
+/**
+ *
+ * LDA recommender
+ *
+ * <pre>
+ * Latent Dirichlet Allocation for implicit feedback
+ * Reference: LibRec team
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class LDAModel extends ProbabilisticGraphicalModel {
+
+    /**
+     * entry[k, i]: number of tokens assigned to topic k, given item i.
+     */
+    private DenseMatrix topicItemNumbers;
+
+    /**
+     * entry[u, k]: number of tokens assigned to topic k, given user u.
+     */
+    private DenseMatrix userTopicNumbers;
+
+    /**
+     * topic assignment as list from the iterator of trainMatrix
+     */
+    private List<Integer> topicAssignments;
+
+    /**
+     * entry[u]: number of tokens rated by user u.
+     */
+    private DenseVector userTokenNumbers;
+
+    /**
+     * entry[k]: number of tokens assigned to topic k.
+     */
+    private DenseVector topicTokenNumbers;
+
+    /**
+     * vectors of hyperparameters alpha and beta
+     */
+    private DenseVector alpha, beta;
+
+    /**
+     * cumulative statistics of theta, phi
+     */
+    private DenseMatrix userTopicSums, topicItemSums;
+
+    /**
+     * posterior probabilities of parameters
+     */
+    private DenseMatrix userTopicProbabilities, topicItemProbabilities;
+
+    private DenseVector sampleProbabilities;
+
+    /**
+     * setup init member method
+     *
+     * @throws ModelException if error occurs
+     */
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+
+        // TODO this code could be removed (use a constant marker or binarize.threshold instead)
+        for (MatrixScalar term : scoreMatrix) {
+            term.setValue(1F);
+        }
+
+        userTopicSums = DenseMatrix.valueOf(userSize, factorSize);
+        topicItemSums = DenseMatrix.valueOf(factorSize, itemSize);
+
+        // initialize count variables.
+        userTopicNumbers = DenseMatrix.valueOf(userSize, factorSize);
+        userTokenNumbers = DenseVector.valueOf(userSize);
+
+        topicItemNumbers = DenseMatrix.valueOf(factorSize, itemSize);
+        topicTokenNumbers = DenseVector.valueOf(factorSize);
+
+        // default values:
+        // Thomas L. Griffiths and Mark Steyvers. Finding scientific topics.
+        // Proceedings of the National Academy of Sciences, 101(suppl 1):5228-5235, 2004.
+        /**
+         * Dirichlet hyperparameter of the user-topic distribution; a typical value is 50/K
+         */
+        float initAlpha = configuration.getFloat("recommender.user.dirichlet.prior", 50F / factorSize);
+        /**
+         * Dirichlet hyperparameter of the topic-item distribution; a typical value is 0.01
+         */
+        float initBeta = configuration.getFloat("recommender.topic.dirichlet.prior", 0.01F);
+        alpha = DenseVector.valueOf(factorSize);
+        alpha.setValues(initAlpha);
+
+        beta = DenseVector.valueOf(itemSize);
+        beta.setValues(initBeta);
+
+        // The z_u,i are initialized to values in [0, K-1] to determine the
+        // initial state of the Markov chain.
+        topicAssignments = new ArrayList<>(scoreMatrix.getElementSize());
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            int times = (int) (term.getValue());
+            for (int time = 0; time < times; time++) {
+                int topicIndex = RandomUtility.randomInteger(factorSize); // 0 ~ K-1
+
+                // assign a topic t to pair (u, i)
+                topicAssignments.add(topicIndex);
+                // number of items of user u assigned to topic t.
+                userTopicNumbers.shiftValue(userIndex, topicIndex, 1F);
+                // total number of items of user u
+                userTokenNumbers.shiftValue(userIndex, 1F);
+                // number of instances of item i assigned to topic t
+                topicItemNumbers.shiftValue(topicIndex, itemIndex, 1F);
+                // total number of words assigned to topic t.
+                topicTokenNumbers.shiftValue(topicIndex, 1F);
+            }
+        }
+
+        sampleProbabilities = DenseVector.valueOf(factorSize);
+    }
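+
+    /*
+     * The collapsed Gibbs conditional sampled by eStep below is
+     *
+     *   p(z = k | u, i) ~ (n_uk + alpha_k) / (n_u + sum(alpha))
+     *                   * (n_ki + beta_i) / (n_k + sum(beta))
+     *
+     * where n_uk counts the tokens of user u assigned to topic k, n_ki counts
+     * the assignments of item i to topic k, and n_u, n_k are the matching
+     * totals, all with the current token removed.
+     */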
+
+    @Override
+    protected void eStep() {
+        float sumAlpha = alpha.getSum(false);
+        float sumBeta = beta.getSum(false);
+
+        // Gibbs sampling from the full conditional distribution
+        int assignmentsIndex = 0;
+
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            int times = (int) (term.getValue());
+            for (int time = 0; time < times; time++) {
+                int topicIndex = topicAssignments.get(assignmentsIndex); // current topic
+
+                userTopicNumbers.shiftValue(userIndex, topicIndex, -1F);
+                userTokenNumbers.shiftValue(userIndex, -1F);
+                topicItemNumbers.shiftValue(topicIndex, itemIndex, -1F);
+                topicTokenNumbers.shiftValue(topicIndex, -1F);
+
+                // compute the sampling probabilities
+                DefaultScalar sum = DefaultScalar.getInstance();
+                sum.setValue(0F);
+                sampleProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                    int index = scalar.getIndex();
+                    float value = (userTopicNumbers.getValue(userIndex, index) + alpha.getValue(index)) / (userTokenNumbers.getValue(userIndex) + sumAlpha) * (topicItemNumbers.getValue(index, itemIndex) + beta.getValue(itemIndex)) / (topicTokenNumbers.getValue(index) + sumBeta);
+                    sum.shiftValue(value);
+                    scalar.setValue(sum.getValue());
+                });
+
+                // scaled sample because of unnormalized p[]; randomly sample a new topic t
+                topicIndex = SampleUtility.binarySearch(sampleProbabilities, 0, sampleProbabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue()));
+
+                // add newly estimated z_i to count variables
+                userTopicNumbers.shiftValue(userIndex, topicIndex, 1F);
+                userTokenNumbers.shiftValue(userIndex, 1F);
+                topicItemNumbers.shiftValue(topicIndex, itemIndex, 1F);
+                topicTokenNumbers.shiftValue(topicIndex, 1F);
+
+                topicAssignments.set(assignmentsIndex, topicIndex);
+                assignmentsIndex++;
+            }
+        }
+    }
+
+    @Override
+    protected void mStep() {
+        float denominator;
+        float value;
+
+        // update the alpha vector
+        float alphaSum = alpha.getSum(false);
+        float alphaDigamma = GammaUtility.digamma(alphaSum);
+        float alphaValue;
+        denominator = 0F;
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            value = userTokenNumbers.getValue(userIndex);
+            if (value != 0F) {
+                denominator += GammaUtility.digamma(value + alphaSum) - alphaDigamma;
+            }
+        }
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            alphaValue = alpha.getValue(topicIndex);
+            alphaDigamma = GammaUtility.digamma(alphaValue);
+            float numerator = 0F;
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                value = userTopicNumbers.getValue(userIndex, topicIndex);
+                if (value != 0F) {
+                    numerator += GammaUtility.digamma(value + alphaValue) - alphaDigamma;
+                }
+            }
+            if (numerator != 0F) {
+                alpha.setValue(topicIndex, alphaValue * (numerator / denominator));
+            }
+        }
+
+        // update the beta vector
+        float betaSum = beta.getSum(false);
+        float betaDigamma = GammaUtility.digamma(betaSum);
+        float betaValue;
+        denominator = 0F;
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            value = topicTokenNumbers.getValue(topicIndex);
+            if (value != 0F) {
+                denominator += GammaUtility.digamma(value + betaSum) - betaDigamma;
+            }
+        }
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            betaValue = beta.getValue(itemIndex);
+            betaDigamma = GammaUtility.digamma(betaValue);
+            float numerator = 0F;
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                value = topicItemNumbers.getValue(topicIndex, itemIndex);
+                if (value != 0F) {
+                    numerator += GammaUtility.digamma(value + betaValue) - betaDigamma;
+                }
+            }
+            if (numerator != 0F) {
+                beta.setValue(itemIndex, betaValue * (numerator / denominator));
+            }
+        }
+    }
+
+    /**
+     * Add to the statistics the values of theta and phi for the current state.
+     */
+    @Override
+    protected void readoutParameters() {
+        float sumAlpha = alpha.getSum(false);
+        float sumBeta = beta.getSum(false);
+        float value;
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                value = (userTopicNumbers.getValue(userIndex, topicIndex) + alpha.getValue(topicIndex)) / (userTokenNumbers.getValue(userIndex) + sumAlpha);
+                userTopicSums.shiftValue(userIndex, topicIndex, value);
+            }
+        }
+
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                value = (topicItemNumbers.getValue(topicIndex, itemIndex) + beta.getValue(itemIndex)) / (topicTokenNumbers.getValue(topicIndex) + sumBeta);
+                topicItemSums.shiftValue(topicIndex, itemIndex, value);
+            }
+        }
+        numberOfStatistics++;
+    }
+
+    @Override
+    protected void estimateParameters() {
+        float scale = 1F / numberOfStatistics;
+        userTopicProbabilities = DenseMatrix.copyOf(userTopicSums);
+        userTopicProbabilities.scaleValues(scale);
+        topicItemProbabilities = DenseMatrix.copyOf(topicItemSums);
+        topicItemProbabilities.scaleValues(scale);
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        DenseVector userVector = userTopicProbabilities.getRowVector(userIndex);
+        DenseVector itemVector = topicItemProbabilities.getColumnVector(itemIndex);
+        instance.setQuantityMark(scalar.dotProduct(userVector, itemVector).getValue());
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMDynamicModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMDynamicModel.java
new file mode 100644
index 0000000..7c60d76
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMDynamicModel.java
@@ -0,0 +1,128 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Arrays;
+import java.util.Comparator;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.module.ArrayInstance;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.utility.LogisticUtility;
+import com.jstarcraft.rns.utility.SampleUtility;
+
+/**
+ *
+ * Lambda FM recommender
+ *
+ * <pre>
+ * LambdaFM: Learning Optimal Ranking with Factorization Machines Using Lambda Surrogates
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class LambdaFMDynamicModel extends LambdaFMModel {
+
+    // dynamic sampling
+    private float dynamicRho;
+
+    private int numberOfOrders;
+
+    private DenseVector orderProbabilities;
+
+    private ArrayInstance[] negatives;
+
+    private Integer[] orderIndexes;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        dynamicRho = configuration.getFloat("recommender.item.distribution.parameter");
+        numberOfOrders = configuration.getInteger("recommender.number.orders", 10);
+
+        DefaultScalar sum = DefaultScalar.getInstance();
+        sum.setValue(0F);
+        orderProbabilities = DenseVector.valueOf(numberOfOrders);
+        orderProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            int index = scalar.getIndex();
+            float value = (float) (Math.exp(-(index + 1) / (numberOfOrders * dynamicRho)));
+            sum.shiftValue(value);
+            scalar.setValue(sum.getValue());
+        });
+        negatives = new ArrayInstance[numberOfOrders];
+        orderIndexes = new Integer[numberOfOrders];
+        for (int index = 0; index < numberOfOrders; index++) {
+            negatives[index] = new ArrayInstance(model.getQualityOrder(), model.getQuantityOrder());
+            orderIndexes[index] = index;
+        }
+    }
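+
+    /*
+     * Dynamic sampling strategy: getGradientValue below draws numberOfOrders
+     * candidate negatives, scores them with the current model, sorts them by
+     * score, then keeps the candidate whose rank was drawn from the
+     * exponential distribution built in prepare, so hard (high-scored)
+     * negatives are the most likely choices.
+     */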
+
+    @Override
+    protected float getGradientValue(DataModule[] modules, ArrayInstance positive, ArrayInstance negative, DefaultScalar scalar) {
+        int userIndex;
+        while (true) {
+            userIndex = RandomUtility.randomInteger(userSize);
+            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+            if (userVector.getElementSize() == 0 || userVector.getElementSize() == itemSize) {
+                continue;
+            }
+
+            DataModule module = modules[userIndex];
+            DataInstance instance = module.getInstance(0);
+            int positivePosition = RandomUtility.randomInteger(module.getSize());
+            instance.setCursor(positivePosition);
+            positive.copyInstance(instance);
+            // TODO negativeGroup.size() may never reach numberOfNegatives; needs handling
+            for (int orderIndex = 0; orderIndex < numberOfOrders; orderIndex++) {
+                int negativeItemIndex = RandomUtility.randomInteger(itemSize - userVector.getElementSize());
+                // skip the items the user has already rated
+                for (int position = 0, size = userVector.getElementSize(); position < size; position++) {
+                    if (negativeItemIndex >= userVector.getIndex(position)) {
+                        negativeItemIndex++;
+                        continue;
+                    }
+                    break;
+                }
+                // TODO note: a negative feature is deliberately fabricated here.
+                int negativePosition = RandomUtility.randomInteger(module.getSize());
+                instance.setCursor(negativePosition);
+                negatives[orderIndex].copyInstance(instance);
+                negatives[orderIndex].setQualityFeature(itemDimension, negativeItemIndex);
+                MathVector vector = getFeatureVector(negatives[orderIndex]);
+                negatives[orderIndex].setQuantityMark(predict(scalar, vector));
+            }
+
+            int orderIndex = SampleUtility.binarySearch(orderProbabilities, 0, orderProbabilities.getElementSize() - 1, RandomUtility.randomFloat(orderProbabilities.getValue(orderProbabilities.getElementSize() - 1)));
+            // sort the candidate negatives by score in descending order
+            Arrays.sort(orderIndexes, new Comparator<Integer>() {
+                @Override
+                public int compare(Integer leftIndex, Integer rightIndex) {
+                    return (negatives[leftIndex].getQuantityMark() > negatives[rightIndex].getQuantityMark() ? -1 : (negatives[leftIndex].getQuantityMark() < negatives[rightIndex].getQuantityMark() ? 1 : 0));
+                }
+            });
+            negative = negatives[orderIndexes[orderIndex]];
+            break;
+        }
+
+        positiveVector = getFeatureVector(positive);
+        negativeVector = getFeatureVector(negative);
+
+        float positiveScore = predict(scalar, positiveVector);
+        float negativeScore = predict(scalar, negativeVector);
+
+        float error = positiveScore - negativeScore;
+
+        // since pij_real defaults to 1, the loss computation simplifies from
+        // loss += -pij_real * Math.log(pij) - (1 - pij_real) * Math.log(1 - pij);
+        totalError += (float) -Math.log(LogisticUtility.getValue(error));
+        float gradient = calaculateGradientValue(lossType, error);
+        return gradient;
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMModel.java
new file mode 100644
index 0000000..8795801
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMModel.java
@@ -0,0 +1,149 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Iterator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.module.ArrayInstance;
+import com.jstarcraft.ai.data.processor.DataSplitter;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter;
+import com.jstarcraft.rns.model.FactorizationMachineModel;
+
+/**
+ *
+ * Lambda FM recommender
+ *
+ * <pre>
+ * LambdaFM: Learning Optimal Ranking with Factorization Machines Using Lambda Surrogates
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public abstract class LambdaFMModel extends FactorizationMachineModel {
+
+    protected int lossType;
+
+    protected MathVector positiveVector, negativeVector;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        // TODO this code could be removed (use a constant marker or binarize.threshold instead)
+        for (MatrixScalar term : scoreMatrix) {
+            term.setValue(1F);
+        }
+
+        lossType = configuration.getInteger("losstype", 3);
+
+        biasRegularization = configuration.getFloat("recommender.fm.regw0", 0.1F);
+        weightRegularization = configuration.getFloat("recommender.fm.regW", 0.1F);
+        factorRegularization = configuration.getFloat("recommender.fm.regF", 0.001F);
+    }
+
+    protected abstract float getGradientValue(DataModule[] modules, ArrayInstance positive, ArrayInstance negative, DefaultScalar scalar);
+
+    @Override
+    protected void doPractice() {
+        ArrayInstance positive = new ArrayInstance(marker.getQualityOrder(), marker.getQuantityOrder());
+        ArrayInstance negative = new ArrayInstance(marker.getQualityOrder(), marker.getQuantityOrder());
+
+        DefaultScalar scalar = DefaultScalar.getInstance();
+
+        DataSplitter splitter = new QualityFeatureDataSplitter(userDimension);
+        DataModule[] modules = splitter.split(marker, userSize);
+
+        DenseVector positiveSum = DenseVector.valueOf(factorSize);
+        DenseVector negativeSum = DenseVector.valueOf(factorSize);
+
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            long totalTime = 0;
+            totalError = 0F;
+            for (int sampleIndex = 0, sampleTimes = userSize * 50; sampleIndex < sampleTimes; sampleIndex++) {
+                long current = System.currentTimeMillis();
+                float gradient = getGradientValue(modules, positive, negative, scalar);
+                totalTime += (System.currentTimeMillis() - current);
+
+                sum(positiveVector, positiveSum);
+                sum(negativeVector, negativeSum);
+                int leftIndex = 0, rightIndex = 0;
+                Iterator<VectorScalar> leftIterator = positiveVector.iterator();
+                Iterator<VectorScalar> rightIterator = negativeVector.iterator();
+                for (int index = 0; index < marker.getQualityOrder(); index++) {
+                    VectorScalar leftTerm = leftIterator.next();
+                    VectorScalar rightTerm = rightIterator.next();
+                    leftIndex = leftTerm.getIndex();
+                    rightIndex = rightTerm.getIndex();
+                    if (leftIndex == rightIndex) {
+                        weightVector.shiftValue(leftIndex, learnRatio * (gradient * 0F - weightRegularization * weightVector.getValue(leftIndex)));
+                        totalError += weightRegularization * weightVector.getValue(leftIndex) * weightVector.getValue(leftIndex);
+
+                        for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                            float positiveFactor = positiveSum.getValue(factorIndex) * leftTerm.getValue() - featureFactors.getValue(leftIndex, factorIndex) * leftTerm.getValue() * leftTerm.getValue();
+                            float negativeFactor = negativeSum.getValue(factorIndex) * rightTerm.getValue() - featureFactors.getValue(rightIndex, factorIndex) * rightTerm.getValue() * rightTerm.getValue();
+
+                            featureFactors.shiftValue(leftIndex, factorIndex, learnRatio * (gradient * (positiveFactor - negativeFactor) - factorRegularization * featureFactors.getValue(leftIndex, factorIndex)));
+                            totalError += factorRegularization * featureFactors.getValue(leftIndex, factorIndex) * featureFactors.getValue(leftIndex, factorIndex);
+                        }
+                    } else {
+                        weightVector.shiftValue(leftIndex, learnRatio * (gradient * leftTerm.getValue() - weightRegularization * weightVector.getValue(leftIndex)));
+                        totalError += weightRegularization * weightVector.getValue(leftIndex) * weightVector.getValue(leftIndex);
+                        weightVector.shiftValue(rightIndex, learnRatio * (gradient * -rightTerm.getValue() - weightRegularization * weightVector.getValue(rightIndex)));
+                        totalError += weightRegularization * weightVector.getValue(rightIndex) * weightVector.getValue(rightIndex);
+
+                        for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                            float positiveFactor = positiveSum.getValue(factorIndex) * leftTerm.getValue() - featureFactors.getValue(leftIndex, factorIndex) * leftTerm.getValue() * leftTerm.getValue();
+                            featureFactors.shiftValue(leftIndex, factorIndex, learnRatio * (gradient * positiveFactor - factorRegularization * featureFactors.getValue(leftIndex, factorIndex)));
+                            totalError += factorRegularization * featureFactors.getValue(leftIndex, factorIndex) * featureFactors.getValue(leftIndex, factorIndex);
+
+                            float negativeFactor = negativeSum.getValue(factorIndex) * rightTerm.getValue() - featureFactors.getValue(rightIndex, factorIndex) * rightTerm.getValue() * rightTerm.getValue();
+                            featureFactors.shiftValue(rightIndex, factorIndex, learnRatio * (gradient * -negativeFactor - factorRegularization * featureFactors.getValue(rightIndex, factorIndex)));
+                            totalError += factorRegularization * featureFactors.getValue(rightIndex, factorIndex) * featureFactors.getValue(rightIndex, factorIndex);
+                        }
+                    }
+                }
+            }
+            System.out.println(totalTime);
+
+            totalError *= 0.5;
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
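+
+    /*
+     * The pairwise objective optimized in doPractice is -log sigmoid(x_ui - x_uj)
+     * for a sampled positive item i and negative item j of user u;
+     * calaculateGradientValue turns the score difference into the gradient
+     * factor for the configured loss type, and totalError accumulates the raw
+     * logistic loss plus the L2 regularization terms.
+     */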
+
+    protected void isLearned(int iteration) {
+        if (learnRatio < 0F) {
+            return;
+        }
+        if (isLearned && iteration > 1) {
+            learnRatio = Math.abs(currentError) > Math.abs(totalError) ? learnRatio * 1.05F : learnRatio * 0.5F;
+        } else if (learnDecay > 0 && learnDecay < 1) {
+            learnRatio *= learnDecay;
+        }
+        // limit to max-learn-rate after update
+        if (learnLimit > 0 && learnRatio > learnLimit) {
+            learnRatio = learnLimit;
+        }
+    }
+
+    private void sum(MathVector vector, DenseVector sum) {
+        // TODO consider switching to vector operations.
+        for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+            float value = 0F;
+            for (VectorScalar term : vector) {
+                value += featureFactors.getValue(term.getIndex(), factorIndex) * term.getValue();
+            }
+            sum.setValue(factorIndex, value);
+        }
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMStaticModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMStaticModel.java
new file mode 100644
index 0000000..48fa40c
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMStaticModel.java
@@ -0,0 +1,130 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Arrays;
+import java.util.Comparator;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.module.ArrayInstance;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.utility.LogisticUtility;
+import com.jstarcraft.rns.utility.SampleUtility;
+
+/**
+ *
+ * Lambda FM recommender
+ *
+ * <pre>
+ * LambdaFM: Learning Optimal Ranking with Factorization Machines Using Lambda Surrogates
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class LambdaFMStaticModel extends LambdaFMModel {
+
+    // static sampling
+    private float staticRho;
+    protected DenseVector itemProbabilities;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        staticRho = configuration.getFloat("recommender.item.distribution.parameter");
+        // calculate popularity
+        Integer[] orderItems = new Integer[itemSize];
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            orderItems[itemIndex] = itemIndex;
+        }
+        // sort the items by popularity in descending order
+        Arrays.sort(orderItems, new Comparator<Integer>() {
+            @Override
+            public int compare(Integer leftItemIndex, Integer rightItemIndex) {
+                return (scoreMatrix.getColumnScope(leftItemIndex) > scoreMatrix.getColumnScope(rightItemIndex) ? -1 : (scoreMatrix.getColumnScope(leftItemIndex) < scoreMatrix.getColumnScope(rightItemIndex) ? 1 : 0));
+            }
+        });
+        Integer[] itemOrders = new Integer[itemSize];
+        for (int index = 0; index < itemSize; index++) {
+            int itemIndex = orderItems[index];
+            itemOrders[itemIndex] = index;
+        }
+        DefaultScalar sum = DefaultScalar.getInstance();
+        sum.setValue(0F);
+        itemProbabilities = DenseVector.valueOf(itemSize);
+        itemProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            int index = scalar.getIndex();
+            float value = (float) Math.exp(-(itemOrders[index] + 1) / (itemSize * staticRho));
+            sum.shiftValue(value);
+            scalar.setValue(sum.getValue());
+        });
+
+        for (MatrixScalar term : scoreMatrix) {
+            term.setValue(itemProbabilities.getValue(term.getColumn()));
+        }
+    }
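+
+    /*
+     * A minimal sketch of the static distribution built in prepare above: the
+     * item at popularity rank r receives weight exp(-(r + 1) / (size * rho)),
+     * stored cumulatively so that a uniform draw in [0, total) can be located
+     * by binary search. The helper name is illustrative.
+     */
+    private static float[] cumulativeRankWeights(int size, float rho) {
+        float[] cumulative = new float[size];
+        float sum = 0F;
+        for (int rank = 0; rank < size; rank++) {
+            // popular items (small rank) get exponentially larger weight
+            sum += (float) Math.exp(-(rank + 1) / (size * rho));
+            cumulative[rank] = sum;
+        }
+        return cumulative;
+    }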
+
+    @Override
+    protected float getGradientValue(DataModule[] modules, ArrayInstance positive, ArrayInstance negative, DefaultScalar scalar) {
+        int userIndex;
+        while (true) {
+            userIndex = RandomUtility.randomInteger(userSize);
+            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+            if (userVector.getElementSize() == 0 || userVector.getElementSize() == itemSize) {
+                continue;
+            }
+
+            DataModule module = modules[userIndex];
+            DataInstance instance = module.getInstance(0);
+            int positivePosition = RandomUtility.randomInteger(module.getSize());
+            instance.setCursor(positivePosition);
+            positive.copyInstance(instance);
+
+            // TODO note: a negative feature is deliberately fabricated here.
+            int negativeItemIndex = -1;
+            while (negativeItemIndex == -1) {
+                int position = SampleUtility.binarySearch(userVector, 0, userVector.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1)));
+                int low;
+                int high;
+                if (position == -1) {
+                    low = userVector.getIndex(userVector.getElementSize() - 1);
+                    high = itemProbabilities.getElementSize() - 1;
+                } else if (position == 0) {
+                    low = 0;
+                    high = userVector.getIndex(position);
+                } else {
+                    low = userVector.getIndex(position - 1);
+                    high = userVector.getIndex(position);
+                }
+                negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, low, high, RandomUtility.randomFloat(itemProbabilities.getValue(high)));
+            }
+            int negativePosition = RandomUtility.randomInteger(module.getSize());
+            instance.setCursor(negativePosition);
+            negative.copyInstance(instance);
+            negative.setQualityFeature(itemDimension, negativeItemIndex);
+            break;
+        }
+
+        positiveVector = getFeatureVector(positive);
+        negativeVector = getFeatureVector(negative);
+
+        float positiveScore = predict(scalar, positiveVector);
+        float negativeScore = predict(scalar, negativeVector);
+
+        float error = positiveScore - negativeScore;
+
+        // since pij_real defaults to 1, the loss computation simplifies from
+        // loss += -pij_real * Math.log(pij) - (1 - pij_real) * Math.log(1 - pij);
+        totalError += (float) -Math.log(LogisticUtility.getValue(error));
+        float gradient = calaculateGradientValue(lossType, error);
+        return gradient;
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMWeightModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMWeightModel.java
new file mode 100644
index 0000000..8cfc343
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMWeightModel.java
@@ -0,0 +1,102 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.module.ArrayInstance;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.utility.LogisticUtility;
+
+/**
+ *
+ * Lambda FM recommender
+ *
+ * <pre>
+ * LambdaFM: Learning Optimal Ranking with Factorization Machines Using Lambda Surrogates
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class LambdaFMWeightModel extends LambdaFMModel {
+
+    // weighting
+    private float[] orderLosses;
+    private float epsilon;
+    private int Y, N;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        epsilon = configuration.getFloat("epsilon");
+        orderLosses = new float[itemSize - 1];
+        float orderLoss = 0F;
+        // normalized harmonic weights: orderLosses[r] = H(r + 1) / H(itemSize - 1)
+        for (int orderIndex = 1; orderIndex < itemSize; orderIndex++) {
+            orderLoss += 1F / orderIndex;
+            orderLosses[orderIndex - 1] = orderLoss;
+        }
+        for (int rankIndex = 1; rankIndex < itemSize; rankIndex++) {
+            orderLosses[rankIndex - 1] /= orderLoss;
+        }
+    }
+
+    @Override
+    protected float getGradientValue(DataModule[] modules, ArrayInstance positive, ArrayInstance negative, DefaultScalar scalar) {
+        int userIndex;
+        float positiveScore;
+        float negativeScore;
+        while (true) {
+            userIndex = RandomUtility.randomInteger(userSize);
+            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+            if (userVector.getElementSize() == 0 || userVector.getElementSize() == itemSize) {
+                continue;
+            }
+
+            N = 0;
+            Y = itemSize - scoreMatrix.getRowScope(userIndex);
+            DataModule module = modules[userIndex];
+            DataInstance instance = module.getInstance(0);
+            int positivePosition = RandomUtility.randomInteger(module.getSize());
+            instance.setCursor(positivePosition);
+            positive.copyInstance(instance);
+            positiveVector = getFeatureVector(positive);
+            positiveScore = predict(scalar, positiveVector);
+            do {
+                N++;
+                int negativeItemIndex = RandomUtility.randomInteger(itemSize - userVector.getElementSize());
+                // skip the items the user has already rated
+                for (int position = 0, size = userVector.getElementSize(); position < size; position++) {
+                    if (negativeItemIndex >= userVector.getIndex(position)) {
+                        negativeItemIndex++;
+                        continue;
+                    }
+                    break;
+                }
+                // TODO note: a negative feature is deliberately fabricated here.
+                int negativePosition = RandomUtility.randomInteger(module.getSize());
+                instance.setCursor(negativePosition);
+                negative.copyInstance(instance);
+                negative.setQualityFeature(itemDimension, negativeItemIndex);
+                negativeVector = getFeatureVector(negative);
+                negativeScore = predict(scalar, negativeVector);
+            } while ((positiveScore - negativeScore > epsilon) && N < Y - 1);
+            break;
+        }
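+
+        /*
+         * N counts the draws needed to find a violating negative and Y counts
+         * the user's unobserved items, so (Y - 1) / N approximates the rank of
+         * the positive item; orderLosses maps that rank to the normalized
+         * harmonic weight applied to the gradient below.
+         */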
+
+        float error = positiveScore - negativeScore;
+
+        // since pij_real defaults to 1, the loss computation simplifies from
+        // loss += -pij_real * Math.log(pij) - (1 - pij_real) * Math.log(1 - pij);
+        totalError += (float) -Math.log(LogisticUtility.getValue(error));
+        float gradient = calaculateGradientValue(lossType, error);
+        int orderIndex = (Y - 1) / N;
+        float orderLoss = orderLosses[orderIndex];
+        gradient = gradient * orderLoss;
+        return gradient;
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ListwiseMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ListwiseMFModel.java
new file mode 100644
index 0000000..60fbf51
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/ListwiseMFModel.java
@@ -0,0 +1,80 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+import com.jstarcraft.rns.utility.LogisticUtility;
+
+/**
+ *
+ * ListwiseMF recommender
+ *
+ * <pre>
+ * List-wise learning to rank with matrix factorization for collaborative filtering
+ * Reference: LibRec team
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class ListwiseMFModel extends MatrixFactorizationModel {
+
+    private DenseVector userExponentials;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        userExponentials = DenseVector.valueOf(userSize);
+        for (MatrixScalar matrixEntry : scoreMatrix) {
+            int userIndex = matrixEntry.getRow();
+            float score = matrixEntry.getValue();
+            userExponentials.shiftValue(userIndex, (float) Math.exp(score));
+        }
+    }
+
+    @Override
+    protected void doPractice() {
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+                if (userVector.getElementSize() == 0) {
+                    continue;
+                }
+                float exponential = 0F;
+                for (VectorScalar term : userVector) {
+                    exponential += Math.exp(predict(userIndex, term.getIndex()));
+                }
+                for (VectorScalar term : userVector) {
+                    int itemIndex = term.getIndex();
+                    float score = term.getValue();
+                    float predict = predict(userIndex, itemIndex);
+                    float error = (float) (Math.exp(score) / userExponentials.getValue(userIndex) - Math.log(Math.exp(predict) / exponential)) * LogisticUtility.getGradient(predict);
+                    totalError -= error;
+                    // update factors
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float userFactor = userFactors.getValue(userIndex, factorIndex);
+                        float itemFactor = itemFactors.getValue(itemIndex, factorIndex);
+                        float userDelta = error * itemFactor - userRegularization * userFactor;
+                        float itemDelta = error * userFactor - itemRegularization * itemFactor;
+                        userFactors.shiftValue(userIndex, factorIndex, learnRatio * userDelta);
+                        itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * itemDelta);
+                        totalError += 0.5F * userRegularization * userFactor * userFactor + 0.5F * itemRegularization * itemFactor * itemFactor;
+                    }
+                }
+            }
+
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
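+
+    /*
+     * Following the referenced paper, the loss above is the cross entropy
+     * between the observed and predicted top-one probabilities over each
+     * user's rated items: the target is exp(r_ui) / sum_j exp(r_uj) and the
+     * prediction is the softmax of the current scores over the same items.
+     */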
+}
\ No newline at end of file
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/PLSAModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/PLSAModel.java
new file mode 100644
index 0000000..128127d
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/PLSAModel.java
@@ -0,0 +1,161 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.table.SparseTable;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+
+import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
+
+/**
+ *
+ * PLSA recommender
+ *
+ * <pre>
+ * Latent semantic models for collaborative filtering
+ * Reference: LibRec team
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class PLSAModel extends ProbabilisticGraphicalModel {
+
+    /**
+     * {user, item, {topic z, probability}}
+     */
+    private SparseTable<DenseVector> probabilityTensor;
+
+    /**
+     * Conditional Probability: P(z|u)
+     */
+    private DenseMatrix userTopicProbabilities, userTopicSums;
+
+    /**
+     * Conditional Probability: P(i|z)
+     */
+    private DenseMatrix topicItemProbabilities, topicItemSums;
+
+    /**
+     * topic probability sum value
+     */
+    private DenseVector topicProbabilities;
+
+    /**
+     * entry[u]: number of tokens rated by user u.
+     */
+    private DenseVector userScoreTimes;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+
+        // TODO this code could be removed (use a constant marker or binarize.threshold instead)
+        for (MatrixScalar term : scoreMatrix) {
+            term.setValue(1F);
+        }
+
+        userTopicSums = DenseMatrix.valueOf(userSize, factorSize);
+        topicItemSums = DenseMatrix.valueOf(factorSize, itemSize);
+        topicProbabilities = DenseVector.valueOf(factorSize);
+
+        userTopicProbabilities = DenseMatrix.valueOf(userSize, factorSize);
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            DenseVector probabilityVector = userTopicProbabilities.getRowVector(userIndex);
+            probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                scalar.setValue(RandomUtility.randomInteger(userSize) + 1);
+            });
+            probabilityVector.scaleValues(1F / probabilityVector.getSum(false));
+        }
+
+        topicItemProbabilities = DenseMatrix.valueOf(factorSize, itemSize);
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            DenseVector probabilityVector = topicItemProbabilities.getRowVector(topicIndex);
+            probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                scalar.setValue(RandomUtility.randomInteger(itemSize) + 1);
+            });
+            probabilityVector.scaleValues(1F / probabilityVector.getSum(false));
+        }
+
+        // initialize Q
+        probabilityTensor = new SparseTable<>(true, userSize, itemSize, new Int2ObjectRBTreeMap<>());
+        userScoreTimes = DenseVector.valueOf(userSize);
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            probabilityTensor.setValue(userIndex, itemIndex, DenseVector.valueOf(factorSize));
+            userScoreTimes.shiftValue(userIndex, term.getValue());
+        }
+    }
+
+    @Override
+    protected void eStep() {
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            DenseVector probabilities = probabilityTensor.getValue(userIndex, itemIndex);
+            probabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                int index = scalar.getIndex();
+                float value = userTopicProbabilities.getValue(userIndex, index) * topicItemProbabilities.getValue(index, itemIndex);
+                scalar.setValue(value);
+            });
+            probabilities.scaleValues(1F / probabilities.getSum(false));
+        }
+    }
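+
+    /*
+     * EM updates for PLSA: the E step above computes the posterior
+     * Q(z | u, i) ~ P(z | u) * P(i | z), normalized over topics; the M step
+     * below re-estimates P(z | u) and P(i | z) as the Q-weighted, normalized
+     * count sums.
+     */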
+
+    @Override
+    protected void mStep() {
+        userTopicSums.setValues(0F);
+        topicItemSums.setValues(0F);
+        topicProbabilities.setValues(0F);
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float numerator = term.getValue();
+            DenseVector probabilities = probabilityTensor.getValue(userIndex, itemIndex);
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                float value = probabilities.getValue(topicIndex) * numerator;
+                userTopicSums.shiftValue(userIndex, topicIndex, value);
+                topicItemSums.shiftValue(topicIndex, itemIndex, value);
+                topicProbabilities.shiftValue(topicIndex, value);
+            }
+        }
+
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            float denominator = userScoreTimes.getValue(userIndex);
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                float value = denominator > 0F ? userTopicSums.getValue(userIndex, topicIndex) / denominator : 0F;
+                userTopicProbabilities.setValue(userIndex, topicIndex, value);
+            }
+        }
+
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            float probability = topicProbabilities.getValue(topicIndex);
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                float value = probability > 0F ? topicItemSums.getValue(topicIndex, itemIndex) / probability : 0F;
+                topicItemProbabilities.setValue(topicIndex, itemIndex, value);
+            }
+        }
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        DenseVector userVector = userTopicProbabilities.getRowVector(userIndex);
+        DenseVector itemVector = topicItemProbabilities.getColumnVector(itemIndex);
+        instance.setQuantityMark(scalar.dotProduct(userVector, itemVector).getValue());
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankALSModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankALSModel.java
new file mode 100644
index 0000000..a3c3380
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankALSModel.java
@@ -0,0 +1,241 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.MatrixUtility;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+
+/**
+ *
+ * Rank ALS recommender
+ *
+ * <pre>
+ * Alternating Least Squares for Personalized Ranking
+ * Reference: LibRec team
+ * </pre>
+ *
+ * @author Birdy
+ *
+ */
+public class RankALSModel extends MatrixFactorizationModel {
+
+    // whether support-based weighting is used (s_i = |U_i|) or not (s_i = 1)
+    private boolean weight;
+
+    private DenseVector weightVector;
+
+    private float sumSupport;
+
+    // TODO consider refactoring into the parent class
+    private List<Integer> userList;
+
+    // TODO consider refactoring into the parent class
+    private List<Integer> itemList;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        weight = configuration.getBoolean("recommender.rankals.support.weight", true);
+        weightVector = DenseVector.valueOf(itemSize);
+        sumSupport = 0;
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            float supportValue = weight ? scoreMatrix.getColumnScope(itemIndex) : 1F;
+            weightVector.setValue(itemIndex, supportValue);
+            sumSupport += supportValue;
+        }
+
+        userList = new LinkedList<>();
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            if (scoreMatrix.getRowVector(userIndex).getElementSize() > 0) {
+                userList.add(userIndex);
+            }
+        }
+        userList = new ArrayList<>(userList);
+
+        itemList = new LinkedList<>();
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            if (scoreMatrix.getColumnVector(itemIndex).getElementSize() > 0) {
+                itemList.add(itemIndex);
+            }
+        }
+        itemList = new ArrayList<>(itemList);
+    }
+
+    @Override
+    protected void doPractice() {
+        // cache the feature computations to limit memory consumption
+        DenseMatrix matrixCache = DenseMatrix.valueOf(factorSize, factorSize);
+        DenseMatrix copyCache = DenseMatrix.valueOf(factorSize, factorSize);
+        DenseVector vectorCache = DenseVector.valueOf(factorSize);
+        DenseMatrix inverseCache = DenseMatrix.valueOf(factorSize, factorSize);
+
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            // P step: update user vectors
+            // feature weight matrix and feature weight vector
+            DenseMatrix factorWeightMatrix = DenseMatrix.valueOf(factorSize, factorSize);
+            DenseVector factorWeightVector = DenseVector.valueOf(factorSize);
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                float weight = weightVector.getValue(itemIndex);
+                DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+                factorWeightMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                    int row = scalar.getRow();
+                    int column = scalar.getColumn();
+                    float value = scalar.getValue();
+                    scalar.setValue(value + (itemVector.getValue(row) * itemVector.getValue(column) * weight));
+                });
+                factorWeightVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                    int index = scalar.getIndex();
+                    float value = scalar.getValue();
+                    scalar.setValue(value + itemVector.getValue(index) * weight);
+                });
+            }
+
+            // user feature matrix, user weight vector, user score vector, user count vector
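+            /*
+             * Each user vector then has a closed-form least-squares solution:
+             * the loop below accumulates a per-user system matrix and
+             * right-hand side from the cached item statistics and solves it
+             * with MatrixUtility.inverse.
+             */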
+            DenseMatrix userDeltas = DenseMatrix.valueOf(userSize, factorSize);
+            DenseVector userWeights = DenseVector.valueOf(userSize);
+            DenseVector userScores = DenseVector.valueOf(userSize);
+            DenseVector userTimes = DenseVector.valueOf(userSize);
+            // build the user features from the item features
+            for (int userIndex : userList) {
+                // for each user
+                SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+
+                // TODO consider refactoring to build fewer arrays
+                DenseMatrix factorValues = DenseMatrix.valueOf(factorSize, factorSize);
+                DenseMatrix copyValues = DenseMatrix.valueOf(factorSize, factorSize);
+                DenseVector rateValues = DenseVector.valueOf(factorSize);
+                DenseVector weightValues = DenseVector.valueOf(factorSize);
+                float weightSum = 0F, rateSum = 0F, timeSum = userVector.getElementSize();
+                for (VectorScalar term : userVector) {
+                    int itemIndex = term.getIndex();
+                    float score = term.getValue();
+                    // double cui = 1;
+                    DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+                    factorValues.iterateElement(MathCalculator.PARALLEL, (scalar) -> {
+                        int row = scalar.getRow();
+                        int column = scalar.getColumn();
+                        float value = scalar.getValue();
+                        scalar.setValue(value + itemVector.getValue(row) * itemVector.getValue(column));
+                    });
+                    // ratings of unrated items will be 0
+                    float weight = weightVector.getValue(itemIndex) * score;
+                    float value;
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        value = itemVector.getValue(factorIndex);
+                        userDeltas.shiftValue(userIndex, factorIndex, value);
+                        rateValues.shiftValue(factorIndex, value * score);
+                        weightValues.shiftValue(factorIndex, value * weight);
+                    }
+
+                    rateSum += score;
+                    weightSum += weight;
+                }
+
+                factorValues.iterateElement(MathCalculator.PARALLEL, (scalar) -> {
+                    int row = scalar.getRow();
+                    int column = scalar.getColumn();
+                    float value = scalar.getValue();
+                    scalar.setValue((row == column ? userRegularization : 0F) + value * sumSupport - (userDeltas.getValue(userIndex, row) * factorWeightVector.getValue(column)) - (factorWeightVector.getValue(row) * userDeltas.getValue(userIndex, column)) + (factorWeightMatrix.getValue(row, column) * timeSum));
+                });
+                float rateScale = rateSum;
+                float weightScale = weightSum;
+                rateValues.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                    int index = scalar.getIndex();
+                    float value = scalar.getValue();
+                    scalar.setValue((value * sumSupport - userDeltas.getValue(userIndex, index) * weightScale) - (factorWeightVector.getValue(index) * rateScale) + (weightValues.getValue(index) * timeSum));
+                });
+                userFactors.getRowVector(userIndex).dotProduct(MatrixUtility.inverse(factorValues, copyValues, inverseCache), false, rateValues, MathCalculator.SERIAL);
+
+                userWeights.setValue(userIndex, weightSum);
+                userScores.setValue(userIndex, rateSum);
+                userTimes.setValue(userIndex, timeSum);
+            }
+
+            // Q step: update item vectors
+            DenseMatrix itemFactorMatrix = DenseMatrix.valueOf(factorSize, factorSize);
+            DenseMatrix itemTimeMatrix = DenseMatrix.valueOf(factorSize, factorSize);
+            DenseVector itemFactorVector = DenseVector.valueOf(factorSize);
+            DenseVector factorValues = DenseVector.valueOf(factorSize);
+            for (int userIndex : userList) {
+                DenseVector userVector = userFactors.getRowVector(userIndex);
+                matrixCache.dotProduct(userVector, userVector, MathCalculator.SERIAL);
+                itemFactorMatrix.addMatrix(matrixCache, false);
+                itemTimeMatrix.iterateElement(MathCalculator.PARALLEL, (scalar) -> {
+                    int row = scalar.getRow();
+                    int column = scalar.getColumn();
+                    float value = scalar.getValue();
+                    scalar.setValue(value + (matrixCache.getValue(row, column) * userTimes.getValue(userIndex)));
+                });
+                itemFactorVector.addVector(vectorCache.dotProduct(matrixCache, false, userDeltas.getRowVector(userIndex), MathCalculator.SERIAL));
+                float rateSum = userScores.getValue(userIndex);
+                factorValues.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                    int index = scalar.getIndex();
+                    float value = scalar.getValue();
+                    scalar.setValue(value + userVector.getValue(index) * rateSum);
+                });
+            }
+
+            // build the item features from the user features
+            for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+                // for each item
+                SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex);
+
+                // TODO consider refactoring to build fewer arrays
+                DenseVector rateValues = DenseVector.valueOf(factorSize);
+                DenseVector weightValues = DenseVector.valueOf(factorSize);
+                DenseVector timeValues = DenseVector.valueOf(factorSize);
+                for (VectorScalar term : itemVector) {
+                    int userIndex = term.getIndex();
+                    float score = term.getValue();
+                    float weight = userWeights.getValue(userIndex);
+                    float time = score * userTimes.getValue(userIndex);
+                    float value;
+                    DenseVector userVector = userFactors.getRowVector(userIndex);
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        value = userVector.getValue(factorIndex);
+                        rateValues.shiftValue(factorIndex, value * score);
+                        weightValues.shiftValue(factorIndex, value * weight);
+                        timeValues.shiftValue(factorIndex, value * time);
+                    }
+                }
+
+                float weight = weightVector.getValue(itemIndex);
+                vectorCache.dotProduct(itemFactorMatrix, false, factorWeightVector, MathCalculator.SERIAL);
+                matrixCache.iterateElement(MathCalculator.PARALLEL, (scalar) -> {
+                    int row = scalar.getRow();
+                    int column = scalar.getColumn();
+                    scalar.setValue(itemFactorMatrix.getValue(row, column) * (weight + 1));
+                });
+                DenseVector itemValues = itemFactors.getRowVector(itemIndex);
vectorCache.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + value = value + (rateValues.getValue(index) * sumSupport) - weightValues.getValue(index) + (itemFactorVector.getValue(index) * weight) - (factorValues.getValue(index) * weight) + (timeValues.getValue(index) * weight); + value = value - scalar.dotProduct(matrixCache.getRowVector(index), itemValues).getValue(); + scalar.setValue(value); + }); + matrixCache.iterateElement(MathCalculator.PARALLEL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue((row == column ? itemRegularization : 0F) + (value / (weight + 1)) * sumSupport + itemTimeMatrix.getValue(row, column) * weight - value); + }); + itemValues.dotProduct(MatrixUtility.inverse(matrixCache, copyCache, inverseCache), false, vectorCache, MathCalculator.SERIAL); + } + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankCDModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankCDModel.java new file mode 100644 index 0000000..52b47b5 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankCDModel.java @@ -0,0 +1,145 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * Rank CD推荐器 + * + *
+ * 
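+ * A coordinate descent model for implicit feedback: user and item factors are
+ * refreshed one dimension at a time under confidence weights w = 1 + alpha * r
+ * (see weightMatrix in prepare).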
+ * + * @author Birdy + * + */ +public class RankCDModel extends MatrixFactorizationModel { + + // private float alpha; + // item confidence + + private float confidence; + + private SparseMatrix weightMatrix; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + // TODO 此处代码可以消除(使用常量Marker代替或者使用binarize.threshold) + for (MatrixScalar term : scoreMatrix) { + term.setValue(1F); + } + + confidence = configuration.getFloat("recommender.rankcd.alpha"); + weightMatrix = SparseMatrix.copyOf(scoreMatrix, false); + weightMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(1F + confidence * scalar.getValue()); + }); + } + + @Override + protected void doPractice() { + // Init caches + double[] userScores = new double[userSize]; + double[] itemScores = new double[itemSize]; + double[] userConfidences = new double[userSize]; + double[] itemConfidences = new double[itemSize]; + + // Init Sq + DenseMatrix itemDeltas = DenseMatrix.valueOf(factorSize, factorSize); + // Init Sp + DenseMatrix userDeltas = DenseMatrix.valueOf(factorSize, factorSize); + + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + itemDeltas.dotProduct(itemFactors, true, itemFactors, false, MathCalculator.SERIAL); + // Step 1: update user factors; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = weightMatrix.getRowVector(userIndex); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + itemScores[itemIndex] = predict(userIndex, itemIndex); + itemConfidences[itemIndex] = term.getValue(); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float numerator = 0F, denominator = userRegularization + itemDeltas.getValue(factorIndex, factorIndex); + // TODO 此处可以改为减法 + for (int k = 0; k < factorSize; k++) { + if (factorIndex != k) { + numerator -= userFactors.getValue(userIndex, k) * itemDeltas.getValue(factorIndex, k); + } + } + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + itemScores[itemIndex] -= userFactors.getValue(userIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex); + numerator += (itemConfidences[itemIndex] - (itemConfidences[itemIndex] - 1) * itemScores[itemIndex]) * itemFactors.getValue(itemIndex, factorIndex); + denominator += (itemConfidences[itemIndex] - 1) * itemFactors.getValue(itemIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex); + } + // update puf + userFactors.setValue(userIndex, factorIndex, numerator / denominator); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + itemScores[itemIndex] += userFactors.getValue(userIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex); + } + } + } + + // Update the Sp cache + userDeltas.dotProduct(userFactors, true, userFactors, false, MathCalculator.SERIAL); + // Step 2: update item factors; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector itemVector = weightMatrix.getColumnVector(itemIndex); + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + userScores[userIndex] = predict(userIndex, itemIndex); + userConfidences[userIndex] = term.getValue(); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float numerator = 0F, denominator = itemRegularization + userDeltas.getValue(factorIndex, factorIndex); + // TODO 此处可以改为减法 + for (int k = 0; k < factorSize; k++) { + if (factorIndex != k) { + 
numerator -= itemFactors.getValue(itemIndex, k) * userDeltas.getValue(k, factorIndex); + } + } + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + userScores[userIndex] -= userFactors.getValue(userIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex); + numerator += (userConfidences[userIndex] - (userConfidences[userIndex] - 1) * userScores[userIndex]) * userFactors.getValue(userIndex, factorIndex); + denominator += (userConfidences[userIndex] - 1) * userFactors.getValue(userIndex, factorIndex) * userFactors.getValue(userIndex, factorIndex); + } + // update qif + itemFactors.setValue(itemIndex, factorIndex, numerator / denominator); + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + userScores[userIndex] += userFactors.getValue(userIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex); + } + } + } + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + // TODO 目前没有totalLoss. + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + instance.setQuantityMark(scalar.dotProduct(userFactors.getRowVector(userIndex), itemFactors.getRowVector(itemIndex)).getValue()); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankSGDModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankSGDModel.java new file mode 100644 index 0000000..aebd7ff --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankSGDModel.java @@ -0,0 +1,99 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.List; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.utility.SampleUtility; + +import it.unimi.dsi.fastutil.ints.IntSet; + +/** + * + * Rank SGD推荐器 + * + *
+ * Collaborative Filtering Ensemble for Ranking
+ * Refer to the LibRec team implementation
+ * 
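+ * For every observed (u, i) pair a negative item j is drawn with probability
+ * proportional to its popularity (a binary search over the cumulative
+ * itemProbabilities), then both factors take a pairwise SGD step.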
+ * + * @author Birdy + * + */ +public class RankSGDModel extends MatrixFactorizationModel { + // item sampling probabilities sorted ascendingly + + protected DenseVector itemProbabilities; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // compute item sampling probability + DefaultScalar sum = DefaultScalar.getInstance(); + sum.setValue(0F); + itemProbabilities = DenseVector.valueOf(itemSize); + itemProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float userSize = scoreMatrix.getColumnScope(index); + // sample items based on popularity + float value = (userSize + 0F) / actionSize; + sum.shiftValue(value); + scalar.setValue(sum.getValue()); + }); + } + + @Override + protected void doPractice() { + List userItemSet = getUserItemSet(scoreMatrix); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + // for each rated user-item (u,i) pair + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + IntSet itemSet = userItemSet.get(userIndex); + int positiveItemIndex = term.getColumn(); + float positiveScore = term.getValue(); + int negativeItemIndex = -1; + + do { + // draw an item j with probability proportional to + // popularity + negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1))); + // ensure that it is unrated by user u + } while (itemSet.contains(negativeItemIndex)); + + float negativeScore = 0F; + // compute predictions + float error = (predict(userIndex, positiveItemIndex) - predict(userIndex, negativeItemIndex)) - (positiveScore - negativeScore); + totalError += error * error; + + // update vectors + float value = learnRatio * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex); + float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex); + + userFactors.shiftValue(userIndex, factorIndex, -value * (positiveItemFactor - negativeItemFactor)); + itemFactors.shiftValue(positiveItemIndex, factorIndex, -value * userFactor); + itemFactors.shiftValue(negativeItemIndex, factorIndex, value * userFactor); + } + } + + totalError *= 0.5D; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankVFCDModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankVFCDModel.java new file mode 100644 index 0000000..e586797 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/RankVFCDModel.java @@ -0,0 +1,325 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import 
com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * + * Rank VFCD推荐器 + * + *
+ * 
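+ * Couples implicit feedback with an item-item relation matrix and 4096
+ * dimensional item features; the user, explicit item, implicit item and
+ * feature factor blocks are each updated by elementwise coordinate descent.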
+ * + * @author Birdy + * + */ +public class RankVFCDModel extends MatrixFactorizationModel { + + /** + * two low-rank item matrices, an item-item similarity was learned as a product + * of these two matrices + */ + private DenseMatrix userFactors, explicitItemFactors; + private float alpha, beta, gamma, lamutaE; + private SparseMatrix featureMatrix; + private DenseVector featureVector; + private int numberOfFeatures; + private DenseMatrix featureFactors, implicitItemFactors, factorMatrix; + private SparseMatrix relationMatrix; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + // TODO 此处代码可以消除(使用常量Marker代替或者使用binarize.threshold) + for (MatrixScalar term : scoreMatrix) { + term.setValue(1F); + } + + alpha = configuration.getFloat("recommender.rankvfcd.alpha", 5F); + beta = configuration.getFloat("recommender.rankvfcd.beta", 10F); + gamma = configuration.getFloat("recommender.rankvfcd.gamma", 50F); + lamutaE = configuration.getFloat("recommender.rankvfcd.lamutaE", 50F); + numberOfFeatures = 4096; + + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + explicitItemFactors = DenseMatrix.valueOf(itemSize, factorSize); + explicitItemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + implicitItemFactors = DenseMatrix.valueOf(itemSize, factorSize); + implicitItemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + featureFactors = DenseMatrix.valueOf(numberOfFeatures, factorSize); + featureFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + // 相关矩阵 + DataModule relationModel = space.getModule("relation"); + // TODO 此处需要重构,leftDimension与rightDimension要配置 + String leftField = configuration.getString("data.model.fields.left"); + String rightField = configuration.getString("data.model.fields.right"); + String coefficientField = configuration.getString("data.model.fields.coefficient"); + int leftDimension = 0; + int rightDimension = 1; + int coefficientDimension = relationModel.getQuantityInner(coefficientField); + HashMatrix relationTable = new HashMatrix(true, itemSize, itemSize, new Long2FloatRBTreeMap()); + for (DataInstance instance : relationModel) { + int itemIndex = instance.getQualityFeature(leftDimension); + int neighborIndex = instance.getQualityFeature(rightDimension); + relationTable.setValue(itemIndex, neighborIndex, instance.getQuantityFeature(coefficientDimension)); + } + relationMatrix = SparseMatrix.valueOf(itemSize, itemSize, relationTable); + relationTable = null; + + // 特征矩阵 + float minimumValue = Float.MAX_VALUE; + float maximumValue = Float.MIN_VALUE; + HashMatrix visualTable = new HashMatrix(true, numberOfFeatures, itemSize, new Long2FloatRBTreeMap()); + DataModule featureModel = space.getModule("article"); + String articleField = configuration.getString("data.model.fields.article"); + String featureField = configuration.getString("data.model.fields.feature"); + String degreeField = configuration.getString("data.model.fields.degree"); + int articleDimension = featureModel.getQualityInner(articleField); + int featureDimension = featureModel.getQualityInner(featureField); + int degreeDimension = featureModel.getQuantityInner(degreeField); + for 
(DataInstance instance : featureModel) { + int itemIndex = instance.getQualityFeature(articleDimension); + int featureIndex = instance.getQualityFeature(featureDimension); + float featureValue = instance.getQuantityFeature(degreeDimension); + if (featureValue < minimumValue) { + minimumValue = featureValue; + } + if (featureValue > maximumValue) { + maximumValue = featureValue; + } + visualTable.setValue(featureIndex, itemIndex, featureValue); + } + featureMatrix = SparseMatrix.valueOf(numberOfFeatures, itemSize, visualTable); + visualTable = null; + float maximum = maximumValue; + float minimum = minimumValue; + featureMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((scalar.getValue() - minimum) / (maximum - minimum)); + }); + + factorMatrix = DenseMatrix.valueOf(factorSize, itemSize); + featureVector = DenseVector.valueOf(numberOfFeatures); + for (MatrixScalar term : featureMatrix) { + int featureIndex = term.getRow(); + float value = featureVector.getValue(featureIndex) + term.getValue() * term.getValue(); + featureVector.setValue(featureIndex, value); + } + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + // Init caches + float[] prediction_users = new float[userSize]; + float[] prediction_items = new float[itemSize]; + float[] prediction_itemrelated = new float[itemSize]; + float[] prediction_relateditem = new float[itemSize]; + float[] w_users = new float[userSize]; + float[] w_items = new float[itemSize]; + float[] q_itemrelated = new float[itemSize]; + float[] q_relateditem = new float[itemSize]; + + DenseMatrix explicitItemDeltas = DenseMatrix.valueOf(factorSize, factorSize); + DenseMatrix implicitItemDeltas = DenseMatrix.valueOf(factorSize, factorSize); + DenseMatrix userDeltas = DenseMatrix.valueOf(factorSize, factorSize); + + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + // Update the Sq cache + explicitItemDeltas.dotProduct(explicitItemFactors, true, explicitItemFactors, false, MathCalculator.SERIAL); + // Step 1: update user factors; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + prediction_items[itemIndex] = scalar.dotProduct(userFactors.getRowVector(userIndex), explicitItemFactors.getRowVector(itemIndex)).getValue(); + w_items[itemIndex] = 1F + alpha * term.getValue(); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float numerator = 0F, denominator = userRegularization + explicitItemDeltas.getValue(factorIndex, factorIndex); + // TODO 此处可以改为减法 + for (int k = 0; k < factorSize; k++) { + if (factorIndex != k) { + numerator -= userFactors.getValue(userIndex, k) * explicitItemDeltas.getValue(factorIndex, k); + } + } + float userFactor = userFactors.getValue(userIndex, factorIndex); + for (VectorScalar entry : userVector) { + int i = entry.getIndex(); + float qif = explicitItemFactors.getValue(i, factorIndex); + prediction_items[i] -= userFactor * qif; + numerator += (w_items[i] - (w_items[i] - 1) * prediction_items[i]) * qif; + denominator += (w_items[i] - 1) * qif * qif; + } + // update puf + userFactor = numerator / denominator; + userFactors.setValue(userIndex, factorIndex, userFactor); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + prediction_items[itemIndex] += userFactor * explicitItemFactors.getValue(itemIndex, factorIndex); + } + } + } + + // 
Update the Sp cache + userDeltas.dotProduct(userFactors, true, userFactors, false, MathCalculator.SERIAL); + implicitItemDeltas.dotProduct(implicitItemFactors, true, implicitItemFactors, false, MathCalculator.SERIAL); + DenseMatrix ETF = factorMatrix; + ETF.dotProduct(featureFactors, true, featureMatrix, false, MathCalculator.PARALLEL); + // Step 2: update item factors; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + SparseVector relationVector = relationMatrix.getRowVector(itemIndex); + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + + prediction_users[userIndex] = scalar.dotProduct(userFactors.getRowVector(userIndex), explicitItemFactors.getRowVector(itemIndex)).getValue(); + w_users[userIndex] = 1F + alpha * term.getValue(); + } + for (VectorScalar term : relationVector) { + int neighborIndex = term.getIndex(); + prediction_itemrelated[neighborIndex] = scalar.dotProduct(explicitItemFactors.getRowVector(itemIndex), implicitItemFactors.getRowVector(neighborIndex)).getValue(); + q_itemrelated[neighborIndex] = 1F + alpha * term.getValue(); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float explicitNumerator = 0F, explicitDenominator = userDeltas.getValue(factorIndex, factorIndex) + itemRegularization; + float implicitNumerator = 0F, implicitDenominator = implicitItemDeltas.getValue(factorIndex, factorIndex); + // TODO 此处可以改为减法 + for (int k = 0; k < factorSize; k++) { + if (factorIndex != k) { + explicitNumerator -= explicitItemFactors.getValue(itemIndex, k) * userDeltas.getValue(k, factorIndex); + implicitNumerator -= explicitItemFactors.getValue(itemIndex, k) * implicitItemDeltas.getValue(k, factorIndex); + } + } + float explicitItemFactor = explicitItemFactors.getValue(itemIndex, factorIndex); + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + float userFactor = userFactors.getValue(userIndex, factorIndex); + prediction_users[userIndex] -= userFactor * explicitItemFactor; + explicitNumerator += (w_users[userIndex] - (w_users[userIndex] - 1) * prediction_users[userIndex]) * userFactor; + explicitDenominator += (w_users[userIndex] - 1) * userFactor * userFactor; + } + for (VectorScalar term : relationVector) { + int neighborIndex = term.getIndex(); + float implicitItemFactor = implicitItemFactors.getValue(neighborIndex, factorIndex); + prediction_itemrelated[neighborIndex] -= implicitItemFactor * explicitItemFactor; + implicitNumerator += (q_itemrelated[neighborIndex] - (q_itemrelated[neighborIndex] - 1) * prediction_itemrelated[neighborIndex]) * implicitItemFactor; + implicitDenominator += (q_itemrelated[neighborIndex] - 1) * implicitItemFactor * implicitItemFactor; + } + // update qif + explicitItemFactor = (explicitNumerator + implicitNumerator * beta + gamma * ETF.getValue(factorIndex, itemIndex)) / (explicitDenominator + implicitDenominator * beta + gamma); + explicitItemFactors.setValue(itemIndex, factorIndex, explicitItemFactor); + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + prediction_users[userIndex] += userFactors.getValue(userIndex, factorIndex) * explicitItemFactor; + } + for (VectorScalar term : relationVector) { + int neighborIndex = term.getIndex(); + prediction_itemrelated[neighborIndex] += implicitItemFactors.getValue(neighborIndex, factorIndex) * explicitItemFactor; + } + } + } + + explicitItemDeltas.dotProduct(explicitItemFactors, true, explicitItemFactors, false, 
MathCalculator.SERIAL); + // Step 1: update Z factors; + for (int neighborIndex = 0; neighborIndex < itemSize; neighborIndex++) { + SparseVector relationVector = relationMatrix.getColumnVector(neighborIndex); + for (VectorScalar term : relationVector) { + int itemIndex = term.getIndex(); + prediction_relateditem[itemIndex] = scalar.dotProduct(explicitItemFactors.getRowVector(itemIndex), implicitItemFactors.getRowVector(neighborIndex)).getValue(); + q_relateditem[itemIndex] = 1F + alpha * term.getValue(); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float numerator = 0F, denominator = explicitItemDeltas.getValue(factorIndex, factorIndex); + // TODO 此处可以改为减法 + for (int k = 0; k < factorSize; k++) { + if (factorIndex != k) { + numerator -= implicitItemFactors.getValue(neighborIndex, k) * explicitItemDeltas.getValue(factorIndex, k); + } + } + float implicitItemFactor = implicitItemFactors.getValue(neighborIndex, factorIndex); + for (VectorScalar term : relationVector) { + int itemIndex = term.getIndex(); + float explicitItemFactor = explicitItemFactors.getValue(itemIndex, factorIndex); + prediction_relateditem[itemIndex] -= implicitItemFactor * explicitItemFactor; + numerator += (q_relateditem[itemIndex] - (q_relateditem[itemIndex] - 1) * prediction_relateditem[itemIndex]) * explicitItemFactor; + denominator += (q_relateditem[itemIndex] - 1) * explicitItemFactor * explicitItemFactor; + } + // update puf + implicitItemFactor = beta * numerator / (beta * denominator + itemRegularization); + implicitItemFactors.setValue(neighborIndex, factorIndex, implicitItemFactor); + for (VectorScalar term : relationVector) { + int itemIndex = term.getIndex(); + prediction_relateditem[itemIndex] += implicitItemFactor * explicitItemFactors.getValue(itemIndex, factorIndex); + } + } + } + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + for (int featureIndex = 0; featureIndex < numberOfFeatures; featureIndex++) { + SparseVector featureVector = featureMatrix.getRowVector(featureIndex); + float numerator = 0F, denominator = featureFactors.getValue(featureIndex, factorIndex); + for (VectorScalar term : featureVector) { + float featureValue = term.getValue(); + int itemIndex = term.getIndex(); + ETF.setValue(factorIndex, itemIndex, ETF.getValue(factorIndex, itemIndex) - denominator * featureValue); + numerator += (explicitItemFactors.getValue(itemIndex, factorIndex) - ETF.getValue(factorIndex, itemIndex)) * featureValue; + } + denominator = numerator * gamma / (gamma * this.featureVector.getValue(featureIndex) + lamutaE); + featureFactors.setValue(featureIndex, factorIndex, denominator); + for (VectorScalar term : featureVector) { + float featureValue = term.getValue(); + int itemIndex = term.getIndex(); + ETF.setValue(factorIndex, itemIndex, ETF.getValue(factorIndex, itemIndex) + denominator * featureValue); + } + } + } + + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + // TODO 目前没有totalLoss. 
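+ // Note: totalError is never accumulated in this method (see the TODO
+ // above), so the loss reported to isConverged is always 0F.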
+ } + factorMatrix.dotProduct(featureFactors, true, featureMatrix, false, MathCalculator.PARALLEL); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + float score = 0F; + if (scoreMatrix.getColumnVector(itemIndex).getElementSize() == 0) { + score = scalar.dotProduct(userFactors.getRowVector(userIndex), factorMatrix.getColumnVector(itemIndex)).getValue(); + } else { + score = scalar.dotProduct(userFactors.getRowVector(userIndex), explicitItemFactors.getRowVector(itemIndex)).getValue(); + } + instance.setQuantityMark(score); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/SLIMModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/SLIMModel.java new file mode 100644 index 0000000..da95169 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/SLIMModel.java @@ -0,0 +1,301 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.Iterator; +import java.util.TreeSet; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.algorithm.correlation.MathCorrelation; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SymmetryMatrix; +import com.jstarcraft.ai.math.structure.vector.ArrayVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.common.reflection.ReflectionUtility; +import com.jstarcraft.core.utility.Integer2FloatKeyValue; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.EpocheModel; +import com.jstarcraft.rns.model.exception.ModelException; + +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; + +/** + * + * SLIM推荐器 + * + *
+ * SLIM: Sparse Linear Methods for Top-N Recommender Systems
+ * Refer to the LibRec team implementation
+ * 
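+ * Learns a sparse item-item coefficient matrix by coordinate descent with
+ * soft-thresholding over each item's k nearest neighbours:
+ * w = (s - l1) / (q + l2) if s > l1; (s + l1) / (q + l2) if s < -l1; else 0,
+ * where s and q are the count-averaged residual and squared-score sums.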
+ * + * @author Birdy + * + */ +public class SLIMModel extends EpocheModel { + + /** + * W in original paper, a sparse matrix of aggregation coefficients + */ + // TODO 考虑修改为对称矩阵? + private DenseMatrix coefficientMatrix; + + /** + * item's nearest neighbors for kNN > 0 + */ + private int[][] itemNeighbors; + + /** + * regularization parameters for the L1 or L2 term + */ + private float regL1Norm, regL2Norm; + + /** + * number of nearest neighbors + */ + private int neighborSize; + + /** + * item similarity matrix + */ + private SymmetryMatrix symmetryMatrix; + + private ArrayVector[] userVectors; + + private ArrayVector[] itemVectors; + + private Comparator comparator = new Comparator() { + + @Override + public int compare(Integer2FloatKeyValue left, Integer2FloatKeyValue right) { + int compare = -(Float.compare(left.getValue(), right.getValue())); + if (compare == 0) { + compare = Integer.compare(left.getKey(), right.getKey()); + } + return compare; + } + + }; + + /** + * initialization + * + * @throws ModelException if error occurs + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + neighborSize = configuration.getInteger("recommender.neighbors.knn.number", 50); + regL1Norm = configuration.getFloat("recommender.slim.regularization.l1", 1.0F); + regL2Norm = configuration.getFloat("recommender.slim.regularization.l2", 1.0F); + + // TODO 考虑重构 + coefficientMatrix = DenseMatrix.valueOf(itemSize, itemSize); + coefficientMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + coefficientMatrix.setValue(itemIndex, itemIndex, 0F); + } + + // initial guesses: make smaller guesses (e.g., W.init(0.01)) to speed + // up training + // TODO 修改为配置枚举 + try { + Class correlationClass = (Class) Class.forName(configuration.getString("recommender.correlation.class")); + MathCorrelation correlation = ReflectionUtility.getInstance(correlationClass); + symmetryMatrix = new SymmetryMatrix(scoreMatrix.getColumnSize()); + correlation.calculateCoefficients(scoreMatrix, true, symmetryMatrix::setValue); + } catch (Exception exception) { + throw new RuntimeException(exception); + } + + // TODO 设置容量 + itemNeighbors = new int[itemSize][]; + Int2ObjectMap> itemNNs = new Int2ObjectOpenHashMap<>(); + for (MatrixScalar term : symmetryMatrix) { + int row = term.getRow(); + int column = term.getColumn(); + if (row == column) { + continue; + } + float value = term.getValue(); + // 忽略相似度为0的物品 + if (value == 0F) { + continue; + } + TreeSet neighbors = itemNNs.get(row); + if (neighbors == null) { + neighbors = new TreeSet<>(comparator); + itemNNs.put(row, neighbors); + } + neighbors.add(new Integer2FloatKeyValue(column, value)); + neighbors = itemNNs.get(column); + if (neighbors == null) { + neighbors = new TreeSet<>(comparator); + itemNNs.put(column, neighbors); + } + neighbors.add(new Integer2FloatKeyValue(row, value)); + } + + // 构建物品邻居映射 + for (Int2ObjectMap.Entry> term : itemNNs.int2ObjectEntrySet()) { + TreeSet neighbors = term.getValue(); + int[] value = new int[neighbors.size() < neighborSize ? 
neighbors.size() : neighborSize]; + int index = 0; + for (Integer2FloatKeyValue neighbor : neighbors) { + value[index++] = neighbor.getKey(); + if (index >= neighborSize) { + break; + } + } + Arrays.sort(value); + itemNeighbors[term.getIntKey()] = value; + } + + userVectors = new ArrayVector[userSize]; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + userVectors[userIndex] = new ArrayVector(scoreMatrix.getRowVector(userIndex)); + } + + itemVectors = new ArrayVector[itemSize]; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + itemVectors[itemIndex] = new ArrayVector(scoreMatrix.getColumnVector(itemIndex)); + } + } + + /** + * train model + * + * @throws ModelException if error occurs + */ + @Override + protected void doPractice() { + float[] scores = new float[userSize]; + // number of iteration cycles + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + // each cycle iterates through one coordinate direction + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + int[] neighborIndexes = itemNeighbors[itemIndex]; + if (neighborIndexes == null) { + continue; + } + ArrayVector itemVector = itemVectors[itemIndex]; + for (VectorScalar term : itemVector) { + scores[term.getIndex()] = term.getValue(); + } + // for each nearest neighbor nearestNeighborItemIdx, update + // coefficienMatrix by the coordinate + // descent update rule + for (int neighborIndex : neighborIndexes) { + itemVector = itemVectors[neighborIndex]; + float valueSum = 0F, rateSum = 0F, errorSum = 0F; + int count = itemVector.getElementSize(); + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + float neighborScore = term.getValue(); + float userScore = scores[userIndex]; + float error = userScore - predict(userIndex, itemIndex, neighborIndexes, neighborIndex); + valueSum += neighborScore * error; + rateSum += neighborScore * neighborScore; + errorSum += error * error; + } + valueSum /= count; + rateSum /= count; + errorSum /= count; + // TODO 此处考虑重构 + float coefficient = coefficientMatrix.getValue(neighborIndex, itemIndex); + totalError += errorSum + 0.5F * regL2Norm * coefficient * coefficient + regL1Norm * coefficient; + if (regL1Norm < Math.abs(valueSum)) { + if (valueSum > 0) { + coefficient = (valueSum - regL1Norm) / (regL2Norm + rateSum); + } else { + // One doubt: in this case, wij<0, however, the + // paper says wij>=0. How to gaurantee that? + coefficient = (valueSum + regL1Norm) / (regL2Norm + rateSum); + } + } else { + coefficient = 0F; + } + coefficientMatrix.setValue(neighborIndex, itemIndex, coefficient); + } + itemVector = itemVectors[itemIndex]; + for (VectorScalar term : itemVector) { + scores[term.getIndex()] = 0F; + } + } + + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + } + + /** + * predict a specific ranking score for user userIdx on item itemIdx. 
+ * + * @param userIndex user index + * @param itemIndex item index + * @param excludIndex excluded item index + * @return a prediction without the contribution of excluded item + */ + private float predict(int userIndex, int itemIndex, int[] neighbors, int currentIndex) { + float value = 0F; + ArrayVector userVector = userVectors[userIndex]; + if (userVector.getElementSize() == 0) { + return value; + } + int leftCursor = 0, rightCursor = 0, leftSize = userVector.getElementSize(), rightSize = neighbors.length; + Iterator iterator = userVector.iterator(); + VectorScalar term = iterator.next(); + // 判断两个有序数组中是否存在相同的数字 + while (leftCursor < leftSize && rightCursor < rightSize) { + if (term.getIndex() == neighbors[rightCursor]) { + if (neighbors[rightCursor] != currentIndex) { + value += term.getValue() * coefficientMatrix.getValue(neighbors[rightCursor], itemIndex); + } + if (iterator.hasNext()) { + term = iterator.next(); + } + leftCursor++; + rightCursor++; + } else if (term.getIndex() > neighbors[rightCursor]) { + rightCursor++; + } else if (term.getIndex() < neighbors[rightCursor]) { + if (iterator.hasNext()) { + term = iterator.next(); + } + leftCursor++; + } + } + return value; + } + + /** + * predict a specific ranking score for user userIdx on item itemIdx. + * + * @param userIndex user index + * @param itemIndex item index + * @return predictive ranking score for user userIdx on item itemIdx + * @throws ModelException if error occurs + */ + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + int[] neighbors = itemNeighbors[itemIndex]; + if (neighbors == null) { + instance.setQuantityMark(0F); + return; + } + instance.setQuantityMark(predict(userIndex, itemIndex, neighbors, -1)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/UserKNNRankingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/UserKNNRankingModel.java new file mode 100644 index 0000000..b0063e5 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/UserKNNRankingModel.java @@ -0,0 +1,76 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Iterator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.rns.model.collaborative.UserKNNModel; + +/** + * + * User KNN推荐器 + * + *
+ * Refer to the LibRec team implementation
+ * 
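+ * The ranking score is the sum of the similarities of the user's nearest
+ * neighbours who rated the item, computed by merging two sorted index lists.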
+ * + * @author Birdy + * + */ +public class UserKNNRankingModel extends UserKNNModel { + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + SparseVector itemVector = itemVectors[itemIndex]; + MathVector neighbors = userNeighbors[userIndex]; + if (itemVector.getElementSize() == 0 || neighbors.getElementSize() == 0) { + instance.setQuantityMark(0F); + return; + } + + float sum = 0F, absolute = 0F; + int count = 0; + int leftCursor = 0, rightCursor = 0, leftSize = itemVector.getElementSize(), rightSize = neighbors.getElementSize(); + Iterator leftIterator = itemVector.iterator(); + VectorScalar leftTerm = leftIterator.next(); + Iterator rightIterator = neighbors.iterator(); + VectorScalar rightTerm = rightIterator.next(); + // 判断两个有序数组中是否存在相同的数字 + while (leftCursor < leftSize && rightCursor < rightSize) { + if (leftTerm.getIndex() == rightTerm.getIndex()) { + count++; + sum += rightTerm.getValue(); + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + leftCursor++; + rightCursor++; + } else if (leftTerm.getIndex() > rightTerm.getIndex()) { + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + rightCursor++; + } else if (leftTerm.getIndex() < rightTerm.getIndex()) { + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + leftCursor++; + } + } + + if (count == 0) { + instance.setQuantityMark(0F); + return; + } + + instance.setQuantityMark(sum); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/VBPRModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/VBPRModel.java new file mode 100644 index 0000000..2515ef9 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/VBPRModel.java @@ -0,0 +1,270 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.ArrayVector; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.utility.LogisticUtility; + +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * + * VBPR推荐器 + * + *
+ * VBPR: Visual Bayesian Personalized Ranking from Implicit Feedback
+ * Refer to the LibRec team implementation
+ * 
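+ * The score adds an item bias, the usual latent factor term and a visual
+ * term: 4096 dimensional item features are projected through featureFactors
+ * and matched against per-user visual factors (userFeatures).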
+ * + * @author Birdy + * + */ +public class VBPRModel extends MatrixFactorizationModel { + + /** + * items biases + */ + private DenseVector itemBiases; + + private float biasRegularization; + + private double featureRegularization; + + private int numberOfFeatures; + private DenseMatrix userFeatures; + private DenseVector itemFeatures; + private DenseMatrix featureFactors; + + private HashMatrix featureTable; + private DenseMatrix factorMatrix; + private DenseVector featureVector; + + /** 采样比例 */ + private int sampleRatio; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + // TODO 此处代码可以消除(使用常量Marker代替或者使用binarize.threshold) + for (MatrixScalar term : scoreMatrix) { + term.setValue(1F); + } + + biasRegularization = configuration.getFloat("recommender.bias.regularization", 0.1F); + // TODO 此处应该修改为配置或者动态计算. + numberOfFeatures = 4096; + featureRegularization = 1000; + sampleRatio = configuration.getInteger("recommender.vbpr.alpha", 5); + + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + itemFeatures = DenseVector.valueOf(numberOfFeatures); + itemFeatures.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + userFeatures = DenseMatrix.valueOf(userSize, factorSize); + userFeatures.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + featureFactors = DenseMatrix.valueOf(factorSize, numberOfFeatures); + featureFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + float minimumValue = Float.MAX_VALUE; + float maximumValue = Float.MIN_VALUE; + featureTable = new HashMatrix(true, itemSize, numberOfFeatures, new Long2FloatRBTreeMap()); + DataModule featureModel = space.getModule("article"); + String articleField = configuration.getString("data.model.fields.article"); + String featureField = configuration.getString("data.model.fields.feature"); + String degreeField = configuration.getString("data.model.fields.degree"); + int articleDimension = featureModel.getQualityInner(articleField); + int featureDimension = featureModel.getQualityInner(featureField); + int degreeDimension = featureModel.getQuantityInner(degreeField); + for (DataInstance instance : featureModel) { + int itemIndex = instance.getQualityFeature(articleDimension); + int featureIndex = instance.getQualityFeature(featureDimension); + float featureValue = instance.getQuantityFeature(degreeDimension); + if (featureValue < minimumValue) { + minimumValue = featureValue; + } + if (featureValue > maximumValue) { + maximumValue = featureValue; + } + featureTable.setValue(itemIndex, featureIndex, featureValue); + } + for (MatrixScalar cell : featureTable) { + float value = (cell.getValue() - minimumValue) / (maximumValue - minimumValue); + featureTable.setValue(cell.getRow(), cell.getColumn(), value); + } + factorMatrix = DenseMatrix.valueOf(factorSize, itemSize); + factorMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector factorVector = DenseVector.valueOf(featureFactors.getRowSize()); + ArrayVector[] featureVectors = new ArrayVector[itemSize]; + for 
(int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + MathVector keyValues = featureTable.getRowVector(itemIndex); + int[] featureIndexes = new int[keyValues.getElementSize()]; + float[] featureValues = new float[keyValues.getElementSize()]; + int position = 0; + for (VectorScalar keyValue : keyValues) { + featureIndexes[position] = keyValue.getIndex(); + featureValues[position] = keyValue.getValue(); + position++; + } + featureVectors[itemIndex] = new ArrayVector(numberOfFeatures, featureIndexes, featureValues); + } + float[] featureValues = new float[numberOfFeatures]; + + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + for (int sampleIndex = 0, numberOfSamples = userSize * sampleRatio; sampleIndex < numberOfSamples; sampleIndex++) { + // randomly draw (u, i, j) + int userKey, positiveItemKey, negativeItemKey; + while (true) { + userKey = RandomUtility.randomInteger(userSize); + SparseVector userVector = scoreMatrix.getRowVector(userKey); + if (userVector.getElementSize() == 0) { + continue; + } + positiveItemKey = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize())); + negativeItemKey = RandomUtility.randomInteger(itemSize - userVector.getElementSize()); + for (VectorScalar term : userVector) { + if (negativeItemKey >= term.getIndex()) { + negativeItemKey++; + } else { + break; + } + } + break; + } + int userIndex = userKey, positiveItemIndex = positiveItemKey, negativeItemIndex = negativeItemKey; + ArrayVector positiveItemVector = featureVectors[positiveItemIndex]; + ArrayVector negativeItemVector = featureVectors[negativeItemIndex]; + // update parameters + float positiveScore = predict(userIndex, positiveItemIndex, scalar.dotProduct(itemFeatures, positiveItemVector).getValue(), factorVector.dotProduct(featureFactors, false, positiveItemVector, MathCalculator.SERIAL)); + float negativeScore = predict(userIndex, negativeItemIndex, scalar.dotProduct(itemFeatures, negativeItemVector).getValue(), factorVector.dotProduct(featureFactors, false, negativeItemVector, MathCalculator.SERIAL)); + float error = LogisticUtility.getValue(positiveScore - negativeScore); + totalError += (float) -Math.log(error); + // update bias + float positiveBias = itemBiases.getValue(positiveItemIndex), negativeBias = itemBiases.getValue(negativeItemIndex); + itemBiases.shiftValue(positiveItemIndex, learnRatio * (error - biasRegularization * positiveBias)); + itemBiases.shiftValue(negativeItemIndex, learnRatio * (-error - biasRegularization * negativeBias)); + totalError += biasRegularization * positiveBias * positiveBias + biasRegularization * negativeBias * negativeBias; + for (VectorScalar term : positiveItemVector) { + featureValues[term.getIndex()] = term.getValue(); + } + for (VectorScalar term : negativeItemVector) { + featureValues[term.getIndex()] -= term.getValue(); + } + // update user/item vectors + // 按照因子切割任务实现并发计算. 
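+ // That is, the update could be parallelized by splitting work per factor;
+ // the commented-out CountDownLatch below is a remnant of that idea.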
+ // CountDownLatch factorLatch = new + // CountDownLatch(numberOfFactors); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex); + float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex); + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (error * (positiveItemFactor - negativeItemFactor) - userRegularization * userFactor)); + itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (error * (userFactor) - itemRegularization * positiveItemFactor)); + itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (error * (-userFactor) - itemRegularization * negativeItemFactor)); + totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveItemFactor * positiveItemFactor + itemRegularization * negativeItemFactor * negativeItemFactor; + + float userFeature = userFeatures.getValue(userIndex, factorIndex); + DenseVector featureVector = featureFactors.getRowVector(factorIndex); + userFeatures.shiftValue(userIndex, factorIndex, learnRatio * (error * (scalar.dotProduct(featureVector, positiveItemVector).getValue() - scalar.dotProduct(featureVector, negativeItemVector).getValue()) - userRegularization * userFeature)); + totalError += userRegularization * userFeature * userFeature; + featureVector.iterateElement(MathCalculator.SERIAL, (element) -> { + int index = element.getIndex(); + float value = element.getValue(); + totalError += featureRegularization * value * value; + value += learnRatio * (error * userFeature * featureValues[index] - featureRegularization * value); + element.setValue(value); + }); + } + // 按照特征切割任务实现并发计算. 
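+ // Likewise, this block could be parallelized by splitting work per feature,
+ // though itemFeatures is currently iterated with MathCalculator.SERIAL.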
+ itemFeatures.iterateElement(MathCalculator.SERIAL, (element) -> { + int index = element.getIndex(); + float value = element.getValue(); + totalError += featureRegularization * value * value; + value += learnRatio * (featureValues[index] - featureRegularization * value); + element.setValue(value); + }); + // try { + // factorLatch.await(); + // } catch (Exception exception) { + // throw new LibrecException(exception); + // } + for (VectorScalar term : positiveItemVector) { + featureValues[term.getIndex()] = 0F; + } + for (VectorScalar term : negativeItemVector) { + featureValues[term.getIndex()] -= 0F; + } + } + + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + + factorMatrix.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + ArrayVector vector = featureVectors[column]; + float value = 0; + for (VectorScalar entry : vector) { + value += featureFactors.getValue(row, entry.getIndex()) * entry.getValue(); + } + element.setValue(value); + }); + featureVector = DenseVector.valueOf(itemSize); + featureVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.dotProduct(itemFeatures, featureVectors[element.getIndex()]).getValue(); + }); + } + + private float predict(int userIndex, int itemIndex, float itemFeature, MathVector factorVector) { + DefaultScalar scalar = DefaultScalar.getInstance(); + scalar.setValue(0F); + scalar.shiftValue(itemBiases.getValue(itemIndex) + itemFeature); + scalar.accumulateProduct(userFactors.getRowVector(userIndex), itemFactors.getRowVector(itemIndex)); + scalar.accumulateProduct(userFeatures.getRowVector(userIndex), factorVector); + return scalar.getValue(); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex, featureVector.getValue(itemIndex), factorMatrix.getColumnVector(itemIndex))); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WARPMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WARPMFModel.java new file mode 100644 index 0000000..0e35427 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WARPMFModel.java @@ -0,0 +1,112 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.utility.LogisticUtility; + +/** + * + * WARP推荐器 + * + *
+ * Refer to the LibRec team implementation
+ * 
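+ * WARP weighting: with Y candidate negatives and N draws needed to find a
+ * violating pair, the estimated rank (Y - 1) / N indexes orderLosses, a
+ * normalized harmonic series that scales the gradient.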
+ * + * @author Birdy + * + */ +public class WARPMFModel extends MatrixFactorizationModel { + + private int lossType; + + private float epsilon; + + private float[] orderLosses; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + lossType = configuration.getInteger("losstype", 3); + epsilon = configuration.getFloat("epsilon"); + orderLosses = new float[itemSize - 1]; + float orderLoss = 0F; + for (int orderIndex = 1; orderIndex < itemSize; orderIndex++) { + orderLoss += 1D / orderIndex; + orderLosses[orderIndex - 1] = orderLoss; + } + for (int rankIndex = 1; rankIndex < itemSize; rankIndex++) { + orderLosses[rankIndex - 1] /= orderLoss; + } + } + + @Override + protected void doPractice() { + int Y, N; + + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + for (int sampleIndex = 0, sampleTimes = userSize * 100; sampleIndex < sampleTimes; sampleIndex++) { + int userIndex, positiveItemIndex, negativeItemIndex; + float positiveScore; + float negativeScore; + while (true) { + userIndex = RandomUtility.randomInteger(userSize); + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0 || userVector.getElementSize() == itemSize) { + continue; + } + + N = 0; + Y = itemSize - scoreMatrix.getRowScope(userIndex); + positiveItemIndex = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize())); + positiveScore = predict(userIndex, positiveItemIndex); + do { + N++; + negativeItemIndex = RandomUtility.randomInteger(itemSize - userVector.getElementSize()); + for (int index = 0, size = userVector.getElementSize(); index < size; index++) { + if (negativeItemIndex >= userVector.getIndex(index)) { + negativeItemIndex++; + continue; + } + break; + } + negativeScore = predict(userIndex, negativeItemIndex); + } while ((positiveScore - negativeScore > epsilon) && N < Y - 1); + break; + } + // update parameters + float error = positiveScore - negativeScore; + + float gradient = calaculateGradientValue(lossType, error); + int orderIndex = (int) ((Y - 1) / N); + float orderLoss = orderLosses[orderIndex]; + gradient = gradient * orderLoss; + + totalError += -Math.log(LogisticUtility.getValue(error)); + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex); + float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex); + + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (gradient * (positiveFactor - negativeFactor) - userRegularization * userFactor)); + itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (gradient * userFactor - itemRegularization * positiveFactor)); + itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (gradient * (-userFactor) - itemRegularization * negativeFactor)); + totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveFactor * positiveFactor + itemRegularization * negativeFactor * negativeFactor; + } + } + + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WBPRModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WBPRModel.java new file mode 100644 index 0000000..661cbbb --- 
/dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WBPRModel.java @@ -0,0 +1,182 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.utility.LogisticUtility; + +import it.unimi.dsi.fastutil.ints.IntSet; + +/** + * + * WBPR推荐器 + * + *
+ * Bayesian Personalized Ranking for Non-Uniformly Sampled Items
+ * Refer to the LibRec team implementation
+ * 
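+ * BPR with non-uniform sampling: per user, the unrated items are kept with
+ * their popularity, normalized into a distribution, and negative items are
+ * drawn from it.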
+ * + * @author Birdy + * + */ +public class WBPRModel extends MatrixFactorizationModel { + /** + * user items Set + */ + // private LoadingCache userItemsSet; + + /** + * pre-compute and sort by item's popularity + */ + private List> itemPopularities; + + private List>[] itemProbabilities; + + /** + * items biases + */ + private DenseVector itemBiases; + + /** + * bias regularization + */ + private float biasRegularization; + + /** + * Guava cache configuration + */ + // protected static String cacheSpec; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + biasRegularization = configuration.getFloat("recommender.bias.regularization", 0.01F); + + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(0.01F)); + }); + + // pre-compute and sort by item's popularity + itemPopularities = new ArrayList<>(itemSize); + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + itemPopularities.add(new KeyValue<>(itemIndex, Double.valueOf(scoreMatrix.getColumnScope(itemIndex)))); + } + Collections.sort(itemPopularities, (left, right) -> { + // 降序 + return right.getValue().compareTo(left.getValue()); + }); + + itemProbabilities = new List[userSize]; + List userItemSet = getUserItemSet(scoreMatrix); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + IntSet scoreSet = userItemSet.get(userIndex); + List> probabilities = new LinkedList<>(); + itemProbabilities[userIndex] = probabilities; + // filter candidate items + double sum = 0; + for (KeyValue term : itemPopularities) { + int itemIndex = term.getKey(); + double popularity = term.getValue(); + if (!scoreSet.contains(itemIndex) && popularity > 0D) { + // make a clone to prevent bugs from normalization + probabilities.add(term); + sum += popularity; + } + } + // normalization + for (KeyValue term : probabilities) { + term.setValue(term.getValue() / sum); + } + } + } + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + for (int sampleIndex = 0, sampleTimes = userSize * 100; sampleIndex < sampleTimes; sampleIndex++) { + // randomly draw (userIdx, posItemIdx, negItemIdx) + int userIndex, positiveItemIndex, negativeItemIndex = 0; + List> probabilities; + while (true) { + userIndex = RandomUtility.randomInteger(userSize); + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0) { + continue; + } + positiveItemIndex = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize())); + // sample j by popularity (probability) + probabilities = itemProbabilities[userIndex]; + double random = RandomUtility.randomDouble(1D); + for (KeyValue term : probabilities) { + if ((random -= term.getValue()) <= 0D) { + negativeItemIndex = term.getKey(); + break; + } + } + break; + } + + // update parameters + float positiveScore = predict(userIndex, positiveItemIndex); + float negativeScore = predict(userIndex, negativeItemIndex); + float error = positiveScore - negativeScore; + float value = (float) -Math.log(LogisticUtility.getValue(error)); + totalError += value; + value = LogisticUtility.getValue(-error); + + // update bias + float positiveBias = itemBiases.getValue(positiveItemIndex), negativeBias = itemBiases.getValue(negativeItemIndex); + itemBiases.shiftValue(positiveItemIndex, learnRatio * (value - biasRegularization 
* positiveBias)); + itemBiases.shiftValue(negativeItemIndex, learnRatio * (-value - biasRegularization * negativeBias)); + totalError += biasRegularization * (positiveBias * positiveBias + negativeBias * negativeBias); + + // update user/item vectors + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex); + float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex); + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (value * (positiveItemFactor - negativeItemFactor) - userRegularization * userFactor)); + itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (value * userFactor - itemRegularization * positiveItemFactor)); + itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (value * (-userFactor) - itemRegularization * negativeItemFactor)); + totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveItemFactor * positiveItemFactor + itemRegularization * negativeItemFactor * negativeItemFactor; + } + } + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector userVector = userFactors.getRowVector(userIndex); + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + return itemBiases.getValue(itemIndex) + scalar.dotProduct(userVector, itemVector).getValue(); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WRMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WRMFModel.java new file mode 100644 index 0000000..37daa68 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/ranking/WRMFModel.java @@ -0,0 +1,202 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Date; +import java.util.concurrent.CountDownLatch; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.environment.EnvironmentContext; +import com.jstarcraft.ai.math.MatrixUtility; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.model.exception.ModelException; + +/** + * + * WRMF推荐器 + * + *
+ * WRMF: Weighted Regularized Matrix Factorization
+ * Collaborative filtering for implicit feedback datasets
+ * Reference: LibRec team
+ * 
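
The WRMFModel below follows the weighted-regularized scheme from the paper above: every observed entry contributes a confidence weight, here log(1 + 10^α · r), paired with a binary preference. As a minimal self-contained sketch of just that transform (an editor's illustration, not part of the patch; the class and method names are invented):

```java
// Illustrative only: mirrors the confidence transform applied in WRMFModel.prepare().
public final class WrmfConfidenceSketch {
    /** c_ui = log(1 + 10^alpha * r_ui); the paired preference p_ui is 1 if r_ui > 0, else 0. */
    static float confidence(float score, float alpha) {
        return (float) Math.log(1F + Math.pow(10F, alpha) * score);
    }

    public static void main(String[] arguments) {
        float alpha = 4F; // matches the default of "recommender.wrmf.weight.coefficient" below
        for (float score : new float[] { 0F, 1F, 5F }) {
            System.out.printf("r_ui=%.0f -> c_ui=%.3f%n", score, confidence(score, alpha));
        }
    }
}
```

The log form keeps heavy users from dominating: confidence grows with the raw count, but sub-linearly.
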
+ * + * @author Birdy + * + */ +public class WRMFModel extends MatrixFactorizationModel { + /** + * confidence weight coefficient + */ + private float weightCoefficient; + + /** + * confindence Minus Identity Matrix{ui} = confidenceMatrix_{ui} - 1 =alpha * + * r_{ui} or log(1+10^alpha * r_{ui}) + */ + // TODO 应该重构为SparseMatrix + private SparseMatrix confindenceMatrix; + + /** + * preferenceMatrix_{ui} = 1 if {@code r_{ui}>0 or preferenceMatrix_{ui} = 0} + */ + // TODO 应该重构为SparseMatrix + private SparseMatrix preferenceMatrix; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + weightCoefficient = configuration.getFloat("recommender.wrmf.weight.coefficient", 4.0f); + + confindenceMatrix = SparseMatrix.copyOf(scoreMatrix, false); + confindenceMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue((float) Math.log(1F + Math.pow(10, weightCoefficient) * scalar.getValue())); + }); + preferenceMatrix = SparseMatrix.copyOf(scoreMatrix, false); + preferenceMatrix.setValues(1F); + } + + private ThreadLocal factorMatrixStorage = new ThreadLocal<>(); + private ThreadLocal copyMatrixStorage = new ThreadLocal<>(); + private ThreadLocal inverseMatrixStorage = new ThreadLocal<>(); + + @Override + protected void constructEnvironment() { + // 缓存特征计算,避免消耗内存 + factorMatrixStorage.set(DenseMatrix.valueOf(factorSize, factorSize)); + copyMatrixStorage.set(DenseMatrix.valueOf(factorSize, factorSize)); + inverseMatrixStorage.set(DenseMatrix.valueOf(factorSize, factorSize)); + } + + @Override + protected void destructEnvironment() { + factorMatrixStorage.remove(); + copyMatrixStorage.remove(); + inverseMatrixStorage.remove(); + } + + @Override + protected void doPractice() { + EnvironmentContext context = EnvironmentContext.getContext(); + // 缓存特征计算,避免消耗内存 + DenseMatrix transposeMatrix = DenseMatrix.valueOf(factorSize, factorSize); + + // To be consistent with the symbols in the paper + // Updating by using alternative least square (ALS) + // due to large amount of entries to be processed (SGD will be too slow) + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + // Step 1: update user factors; + // 按照用户切割任务实现并发计算. + DenseMatrix itemSymmetryMatrix = transposeMatrix; + itemSymmetryMatrix.dotProduct(itemFactors, true, itemFactors, false, MathCalculator.SERIAL); + CountDownLatch userLatch = new CountDownLatch(userSize); + for (int index = 0; index < userSize; index++) { + int userIndex = index; + context.doAlgorithmByAny(index, () -> { + DenseMatrix factorMatrix = factorMatrixStorage.get(); + DenseMatrix copyMatrix = copyMatrixStorage.get(); + DenseMatrix inverseMatrix = inverseMatrixStorage.get(); + SparseVector confindenceVector = confindenceMatrix.getRowVector(userIndex); + // YtY + Yt * (Cu - itemIdx) * Y + factorMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = 0F; + for (VectorScalar term : confindenceVector) { + int itemIndex = term.getIndex(); + value += itemFactors.getValue(itemIndex, row) * term.getValue() * itemFactors.getValue(itemIndex, column); + } + value += itemSymmetryMatrix.getValue(row, column); + value += userRegularization; + scalar.setValue(value); + }); + // (YtCuY + lambda * itemIdx)^-1 + // lambda * itemIdx can be pre-difined because every time is + // the + // same. 
+ // Yt * (Cu - itemIdx) * Pu + Yt * Pu + DenseVector userFactorVector = DenseVector.valueOf(factorSize); + SparseVector preferenceVector = preferenceMatrix.getRowVector(userIndex); + for (int position = 0, size = preferenceVector.getElementSize(); position < size; position++) { + int itemIndex = preferenceVector.getIndex(position); + float confindence = confindenceVector.getValue(position); + float preference = preferenceVector.getValue(position); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + userFactorVector.shiftValue(factorIndex, preference * (itemFactors.getValue(itemIndex, factorIndex) * confindence + itemFactors.getValue(itemIndex, factorIndex))); + } + } + // udpate user factors + userFactors.getRowVector(userIndex).dotProduct(MatrixUtility.inverse(factorMatrix, copyMatrix, inverseMatrix), false, userFactorVector, MathCalculator.SERIAL); + userLatch.countDown(); + }); + } + try { + userLatch.await(); + } catch (Exception exception) { + throw new ModelException(exception); + } + + // Step 2: update item factors; + // 按照物品切割任务实现并发计算. + DenseMatrix userSymmetryMatrix = transposeMatrix; + userSymmetryMatrix.dotProduct(userFactors, true, userFactors, false, MathCalculator.SERIAL); + CountDownLatch itemLatch = new CountDownLatch(itemSize); + for (int index = 0; index < itemSize; index++) { + int itemIndex = index; + context.doAlgorithmByAny(index, () -> { + DenseMatrix factorMatrix = factorMatrixStorage.get(); + DenseMatrix copyMatrix = copyMatrixStorage.get(); + DenseMatrix inverseMatrix = inverseMatrixStorage.get(); + SparseVector confindenceVector = confindenceMatrix.getColumnVector(itemIndex); + // XtX + Xt * (Ci - itemIdx) * X + factorMatrix.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = 0F; + for (VectorScalar term : confindenceVector) { + int userIndex = term.getIndex(); + value += userFactors.getValue(userIndex, row) * term.getValue() * userFactors.getValue(userIndex, column); + } + value += userSymmetryMatrix.getValue(row, column); + value += itemRegularization; + scalar.setValue(value); + }); + // (XtCuX + lambda * itemIdx)^-1 + // lambda * itemIdx can be pre-difined because every time is + // the + // same. 
+ // Xt * (Ci - itemIdx) * Pu + Xt * Pu + DenseVector itemFactorVector = DenseVector.valueOf(factorSize); + SparseVector preferenceVector = preferenceMatrix.getColumnVector(itemIndex); + for (int position = 0, size = preferenceVector.getElementSize(); position < size; position++) { + int userIndex = preferenceVector.getIndex(position); + float confindence = confindenceVector.getValue(position); + float preference = preferenceVector.getValue(position); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + itemFactorVector.shiftValue(factorIndex, preference * (userFactors.getValue(userIndex, factorIndex) * confindence + userFactors.getValue(userIndex, factorIndex))); + } + } + // udpate item factors + itemFactors.getRowVector(itemIndex).dotProduct(MatrixUtility.inverse(factorMatrix, copyMatrix, inverseMatrix), false, itemFactorVector, MathCalculator.SERIAL); + itemLatch.countDown(); + }); + } + try { + itemLatch.await(); + } catch (Exception exception) { + throw new ModelException(exception); + } + + if (logger.isInfoEnabled()) { + logger.info(getClass() + " runs at iteration = " + epocheIndex + " " + new Date()); + } + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/ASVDPlusPlusModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/ASVDPlusPlusModel.java new file mode 100644 index 0000000..3ddfecc --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/ASVDPlusPlusModel.java @@ -0,0 +1,128 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; + +/** + * + * Asymmetric SVD++推荐器 + * + *
+ * Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model
+ * Reference: LibRec team
+ * 
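
The class below extends BiasedMFModel with two extra item-factor matrices driven by the user's rating history: y_j for pure implicit feedback and x_j weighted by the centered rating, both normalized by the square root of the history size. An editor's sketch of the resulting score against plain arrays (names are hypothetical, not the patch's API):

```java
// Illustrative only: the Asymmetric SVD++ score for one (user, item) pair.
public final class AsvdScoreSketch {
    static float score(float meanScore, float userBias, float itemBias,
            float[] userVector, float[] itemVector, float[] itemBiases,
            int[] ratedItems, float[] ratedScores,
            float[][] positiveFactors, float[][] negativeFactors) {
        float value = meanScore + userBias + itemBias;
        float norm = ratedItems.length > 0 ? (float) Math.sqrt(ratedItems.length) : 1F;
        for (int factor = 0; factor < userVector.length; factor++) {
            float sum = userVector[factor];
            for (int position = 0; position < ratedItems.length; position++) {
                int rated = ratedItems[position];
                // y_j models plain co-occurrence; x_j is scaled by the centered rating
                float scale = ratedScores[position] - meanScore - userBias - itemBiases[rated];
                sum += (positiveFactors[rated][factor] + scale * negativeFactors[rated][factor]) / norm;
            }
            value += sum * itemVector[factor];
        }
        return value;
    }
}
```
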
+ * + * @author Birdy + * + */ +public class ASVDPlusPlusModel extends BiasedMFModel { + + private DenseMatrix positiveFactors, negativeFactors; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + positiveFactors = DenseMatrix.valueOf(itemSize, factorSize); + positiveFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + negativeFactors = DenseMatrix.valueOf(itemSize, factorSize); + negativeFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + } + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + // TODO 目前没有totalLoss. + totalError = 0f; + for (MatrixScalar matrixTerm : scoreMatrix) { + int userIndex = matrixTerm.getRow(); + int itemIndex = matrixTerm.getColumn(); + float score = matrixTerm.getValue(); + float predict = predict(userIndex, itemIndex); + float error = score - predict; + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + + // update factors + float userBiasValue = userBiases.getValue(userIndex); + userBiases.shiftValue(userIndex, learnRatio * (error - regBias * userBiasValue)); + float itemBiasValue = itemBiases.getValue(itemIndex); + itemBiases.shiftValue(itemIndex, learnRatio * (error - regBias * itemBiasValue)); + + float squareRoot = (float) Math.sqrt(userVector.getElementSize()); + float[] positiveSums = new float[factorSize]; + float[] negativeSums = new float[factorSize]; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float positiveSum = 0F; + float negativeSum = 0F; + for (VectorScalar term : userVector) { + int ItemIdx = term.getIndex(); + positiveSum += positiveFactors.getValue(ItemIdx, factorIndex); + negativeSum += negativeFactors.getValue(ItemIdx, factorIndex) * (score - meanScore - userBiases.getValue(userIndex) - itemBiases.getValue(ItemIdx)); + } + positiveSums[factorIndex] = squareRoot > 0 ? positiveSum / squareRoot : positiveSum; + negativeSums[factorIndex] = squareRoot > 0 ? 
negativeSum / squareRoot : negativeSum; + } + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float itemFactor = itemFactors.getValue(itemIndex, factorIndex); + float userValue = error * itemFactor - userRegularization * userFactor; + float itemValue = error * (userFactor + positiveSums[factorIndex] + negativeSums[factorIndex]) - itemRegularization * itemFactor; + userFactors.shiftValue(userIndex, factorIndex, learnRatio * userValue); + itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * itemValue); + for (VectorScalar term : userVector) { + int index = term.getIndex(); + float positiveFactor = positiveFactors.getValue(index, factorIndex); + float negativeFactor = negativeFactors.getValue(index, factorIndex); + float positiveDelta = error * itemFactor / squareRoot - userRegularization * positiveFactor; + float negativeDelta = error * itemFactor * (score - meanScore - userBiases.getValue(userIndex) - itemBiases.getValue(index)) / squareRoot - userRegularization * negativeFactor; + positiveFactors.shiftValue(index, factorIndex, learnRatio * positiveDelta); + negativeFactors.shiftValue(index, factorIndex, learnRatio * negativeDelta); + } + } + } + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector userVector = userFactors.getRowVector(userIndex); + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + float value = meanScore + userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex) + scalar.dotProduct(userVector, itemVector).getValue(); + SparseVector rateVector = scoreMatrix.getRowVector(userIndex); + float squareRoot = (float) Math.sqrt(rateVector.getElementSize()); + for (VectorScalar term : rateVector) { + itemIndex = term.getIndex(); + DenseVector positiveVector = positiveFactors.getRowVector(itemIndex); + DenseVector negativeVector = negativeFactors.getRowVector(itemIndex); + value += scalar.dotProduct(positiveVector, itemVector).getValue() / squareRoot; + float scale = term.getValue() - meanScore - userBiases.getValue(userIndex) - itemBiases.getValue(itemIndex); + value += scalar.dotProduct(negativeVector, itemVector).getValue() * scale / squareRoot; + } + if (Float.isNaN(value)) { + value = meanScore; + } + return value; + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + +} \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AspectModelRatingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AspectModelRatingModel.java new file mode 100644 index 0000000..45cbec6 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AspectModelRatingModel.java @@ -0,0 +1,189 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.MathUtility; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.core.common.option.Option; +import 
com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+import com.jstarcraft.rns.utility.GaussianUtility;
+
+/**
+ *
+ * Aspect Model recommender
+ *
+
+ * Latent class models for collaborative filtering
+ * Reference: LibRec team
+ * 
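
In the EM loop below, the E step assigns each rating event a responsibility for every latent class z, proportional to P(z) · P(u|z) · P(i|z) · N(r; μ_z, σ_z²), and the M step re-estimates all distributions from these soft counts. A self-contained sketch of the responsibility computation (editor's illustration with invented names; the inline normal density stands in for GaussianUtility.probabilityDensity):

```java
// Illustrative only: E-step responsibilities Q(z | u, i, r) for one rating event.
public final class AspectResponsibilitySketch {
    static float[] responsibilities(float score, float[] topicPrior,
            float[] userGivenTopic, float[] itemGivenTopic, float[] means, float[] variances) {
        float[] posterior = new float[topicPrior.length];
        float total = 0F;
        for (int z = 0; z < topicPrior.length; z++) {
            double density = Math.exp(-(score - means[z]) * (score - means[z]) / (2D * variances[z]))
                    / Math.sqrt(2D * Math.PI * variances[z]);
            posterior[z] = (float) (topicPrior[z] * userGivenTopic[z] * itemGivenTopic[z] * density);
            total += posterior[z];
        }
        for (int z = 0; z < posterior.length; z++) {
            posterior[z] = total > 0F ? posterior[z] / total : 0F; // normalize to a distribution
        }
        return posterior;
    }
}
```
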
+ * + * @author Birdy + * + */ +public class AspectModelRatingModel extends ProbabilisticGraphicalModel { + /* + * Conditional distribution: P(u|z) + */ + private DenseMatrix userProbabilities, userSums; + /* + * Conditional distribution: P(i|z) + */ + private DenseMatrix itemProbabilities, itemSums; + /* + * topic distribution: P(z) + */ + private DenseVector topicProbabilities, topicSums; + /* + * + */ + private DenseVector meanProbabilities, meanSums; + /* + * + */ + private DenseVector varianceProbabilities, varianceSums; + + /* + * small value + */ + private static float smallValue = MathUtility.EPSILON; + /* + * {user, item, {topic z, probability}} + */ + private float[][] probabilityTensor; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // Initialize topic distribution + topicProbabilities = DenseVector.valueOf(factorSize); + topicProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomInteger(factorSize) + 1); + }); + topicProbabilities.scaleValues(1F / topicProbabilities.getSum(false)); + topicSums = DenseVector.valueOf(factorSize); + + // intialize conditional distribution P(u|z) + userProbabilities = DenseMatrix.valueOf(factorSize, userSize); + userSums = DenseMatrix.valueOf(factorSize, userSize); + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + DenseVector probabilityVector = userProbabilities.getRowVector(topicIndex); + probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomInteger(userSize) + 1); + }); + probabilityVector.scaleValues(1F / probabilityVector.getSum(false)); + } + + // initialize conditional distribution P(i|z) + itemProbabilities = DenseMatrix.valueOf(factorSize, itemSize); + itemSums = DenseMatrix.valueOf(factorSize, itemSize); + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + DenseVector probabilityVector = itemProbabilities.getRowVector(topicIndex); + probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomInteger(itemSize) + 1); + }); + probabilityVector.scaleValues(1F / probabilityVector.getSum(false)); + } + + // initialize Q + probabilityTensor = new float[actionSize][factorSize]; + + float globalMean = scoreMatrix.getSum(false) / scoreMatrix.getElementSize(); + meanProbabilities = DenseVector.valueOf(factorSize); + varianceProbabilities = DenseVector.valueOf(factorSize); + meanSums = DenseVector.valueOf(factorSize); + varianceSums = DenseVector.valueOf(factorSize); + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + meanProbabilities.setValue(topicIndex, globalMean); + varianceProbabilities.setValue(topicIndex, 2); + } + } + + @Override + protected void eStep() { + topicSums.setValues(smallValue); + userSums.setValues(0F); + itemSums.setValues(0F); + meanSums.setValues(0F); + varianceSums.setValues(smallValue); + // variational inference to compute Q + int actionIndex = 0; + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float denominator = 0F; + float[] numerator = probabilityTensor[actionIndex++]; + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float value = topicProbabilities.getValue(topicIndex) * userProbabilities.getValue(topicIndex, userIndex) * itemProbabilities.getValue(topicIndex, itemIndex) * GaussianUtility.probabilityDensity(term.getValue(), 
meanProbabilities.getValue(topicIndex), varianceProbabilities.getValue(topicIndex)); + numerator[topicIndex] = value; + denominator += value; + } + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float probability = denominator > 0 ? numerator[topicIndex] / denominator : 0F; + numerator[topicIndex] = probability; + } + + float score = term.getValue(); + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float probability = numerator[topicIndex]; + topicSums.shiftValue(topicIndex, probability); + userSums.shiftValue(topicIndex, userIndex, probability); + itemSums.shiftValue(topicIndex, itemIndex, probability); + meanSums.shiftValue(topicIndex, score * probability); + } + } + + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float mean = meanSums.getValue(topicIndex) / topicSums.getValue(topicIndex); + meanProbabilities.setValue(topicIndex, mean); + } + + actionIndex = 0; + for (MatrixScalar term : scoreMatrix) { + float[] probabilities = probabilityTensor[actionIndex++]; + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float mean = meanProbabilities.getValue(topicIndex); + float error = term.getValue() - mean; + float probability = probabilities[topicIndex]; + varianceSums.shiftValue(topicIndex, error * error * probability); + } + } + } + + @Override + protected void mStep() { + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + varianceProbabilities.setValue(topicIndex, varianceSums.getValue(topicIndex) / topicSums.getValue(topicIndex)); + topicProbabilities.setValue(topicIndex, topicSums.getValue(topicIndex) / actionSize); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + userProbabilities.setValue(topicIndex, userIndex, userSums.getValue(topicIndex, userIndex) / topicSums.getValue(topicIndex)); + } + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + itemProbabilities.setValue(topicIndex, itemIndex, itemSums.getValue(topicIndex, itemIndex) / topicSums.getValue(topicIndex)); + } + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = 0F; + float denominator = 0F; + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + float weight = topicProbabilities.getValue(topicIndex) * userProbabilities.getValue(topicIndex, userIndex) * itemProbabilities.getValue(topicIndex, itemIndex); + denominator += weight; + value += weight * meanProbabilities.getValue(topicIndex); + } + value = value / denominator; + if (Float.isNaN(value)) { + value = meanScore; + } + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecLearner.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecLearner.java new file mode 100644 index 0000000..4d8f62a --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecLearner.java @@ -0,0 +1,91 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.linalg.primitives.Pair; + +/** + * + * AutoRec学习器 + * + *
+ * AutoRec: Autoencoders Meet Collaborative Filtering
+ * Reference: LibRec team
+ * 
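
AutoRecLearner plugs a masked squared error into DL4J's ILossFunction interface: reconstruction error is counted only where the user-item matrix has an observed entry. Stripped of the INDArray plumbing, the quantity being computed is roughly the following (an editor's sketch with invented names, not the patch's code):

```java
// Illustrative only: masked squared error and its gradient with respect to the output.
public final class MaskedLossSketch {
    static float lossAndGradient(float[] labels, float[] outputs, float[] mask, float[] gradientOut) {
        float loss = 0F;
        for (int index = 0; index < labels.length; index++) {
            float difference = labels[index] - outputs[index];
            loss += mask[index] * difference * difference; // unobserved cells contribute nothing
            gradientOut[index] = -2F * difference * mask[index]; // dL/d(output)
        }
        return loss;
    }
}
```
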
+ * + * @author Birdy + * + */ +public class AutoRecLearner implements ILossFunction { + + private INDArray maskData; + + public AutoRecLearner(INDArray maskData) { + this.maskData = maskData; + } + + private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray scoreArr; + INDArray output = activationFn.getActivation(preOutput.dup(), true); + INDArray yMinusyHat = Transforms.abs(labels.sub(output)); + scoreArr = yMinusyHat.mul(yMinusyHat); + scoreArr = scoreArr.mul(maskData); + + if (mask != null) { + scoreArr.muliColumnVector(mask); + } + return scoreArr; + } + + @Override + public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask); + double score = scoreArr.sumNumber().doubleValue(); + + if (average) { + score /= scoreArr.size(0); + } + + return score; + } + + @Override + public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask); + return scoreArr.sum(1); + } + + @Override + public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray output = activationFn.getActivation(preOutput.dup(), true); + INDArray yMinusyHat = labels.sub(output); + INDArray dldyhat = yMinusyHat.mul(-2); + + INDArray gradients = activationFn.backprop(preOutput.dup(), dldyhat).getFirst(); + gradients = gradients.mul(maskData); + // multiply with masks, always + if (mask != null) { + gradients.muliColumnVector(mask); + } + + return gradients; + } + + @Override + public Pair computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average), computeGradient(labels, preOutput, activationFn, mask)); + } + + @Override + public String toString() { + return super.toString() + "AutoRecLossFunction"; + } + + @Override + public String name() { + // TODO Auto-generated method stub + return toString(); + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecModel.java new file mode 100644 index 0000000..26f2704 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecModel.java @@ -0,0 +1,75 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Nesterovs; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.NeuralNetworkModel; + +/** + * + * AutoRec学习器 + * + *
+ * AutoRec: Autoencoders Meet Collaborative Filtering
+ * Reference: LibRec team
+ * 
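
AutoRecModel feeds the network an item-major dense copy of the rating matrix plus a 0/1 mask of the same shape, which the loss above uses to ignore unobserved cells. A sketch of that densification under a hypothetical (user, item, score) triple format (editor's illustration only):

```java
// Illustrative only: dense input and mask from sparse rating triples, item-major
// to match the [itemSize, userSize] shape used in AutoRecModel.prepare().
public final class AutoRecInputSketch {
    static void densify(int[][] triples, float[][] input, float[][] mask) {
        for (int[] triple : triples) {
            int user = triple[0], item = triple[1], score = triple[2];
            if (score > 0) {
                input[item][user] = score;
                mask[item][user] = 1F;
            }
        }
    }
}
```
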
+ * + * @author Birdy + * + */ +public class AutoRecModel extends NeuralNetworkModel { + + /** + * the data structure that indicates which element in the user-item is non-zero + */ + private INDArray maskData; + + @Override + protected int getInputDimension() { + return userSize; + } + + @Override + protected MultiLayerConfiguration getNetworkConfiguration() { + NeuralNetConfiguration.ListBuilder factory = new NeuralNetConfiguration.Builder().seed(6).updater(new Nesterovs(learnRatio, momentum)).weightInit(WeightInit.XAVIER_UNIFORM).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(weightRegularization).list(); + factory.layer(0, new DenseLayer.Builder().nIn(inputDimension).nOut(hiddenDimension).activation(Activation.fromString(hiddenActivation)).build()); + factory.layer(1, new OutputLayer.Builder(new AutoRecLearner(maskData)).nIn(hiddenDimension).nOut(inputDimension).activation(Activation.fromString(outputActivation)).build()); + MultiLayerConfiguration configuration = factory.pretrain(false).backprop(true).build(); + return configuration; + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // transform the sparse matrix to INDArray + int[] matrixShape = new int[] { itemSize, userSize }; + inputData = Nd4j.zeros(matrixShape); + maskData = Nd4j.zeros(matrixShape); + for (MatrixScalar term : scoreMatrix) { + if (term.getValue() > 0D) { + inputData.putScalar(term.getColumn(), term.getRow(), term.getValue()); + maskData.putScalar(term.getColumn(), term.getRow(), 1D); + } + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(outputData.getFloat(itemIndex, userIndex)); + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BHFreeRatingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BHFreeRatingModel.java new file mode 100644 index 0000000..74e2fde --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BHFreeRatingModel.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Map.Entry; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.rns.model.collaborative.BHFreeModel; + +/** + * + * BH Free推荐器 + * + *
+ * Balancing Prediction and Recommendation Accuracy: Hierarchical Latent Factors for Preference Data
+ * Reference: LibRec team
+ * 
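
Prediction here marginalizes over user-topic/item-topic pairs to obtain an unnormalized distribution over the discrete score levels, then returns its expectation; the same pattern reappears in BUCMRatingModel further down this patch, with a single topic index. As an editor's sketch under a hypothetical array layout:

```java
// Illustrative only: expected rating over discrete score levels, with the
// rating distribution marginalized over (userTopic, itemTopic) pairs.
public final class LatentClassExpectationSketch {
    static float expectedScore(float[] scores, float[] userTopics /* P(k|u) */,
            float[][] itemTopics /* P(l|k) */, float[][][] scoreProbabilities /* P(r|k,l) */) {
        float value = 0F, normalizer = 0F;
        for (int level = 0; level < scores.length; level++) {
            float probability = 0F;
            for (int k = 0; k < userTopics.length; k++) {
                for (int l = 0; l < itemTopics[k].length; l++) {
                    probability += userTopics[k] * itemTopics[k][l] * scoreProbabilities[k][l][level];
                }
            }
            value += scores[level] * probability;
            normalizer += probability;
        }
        return normalizer > 0F ? value / normalizer : 0F;
    }
}
```
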
+ * + * @author Birdy + * + */ +public class BHFreeRatingModel extends BHFreeModel { + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = 0F, probabilities = 0F; + for (Entry entry : scoreIndexes.entrySet()) { + float score = entry.getKey(); + float probability = 0F; + for (int userTopic = 0; userTopic < userTopicSize; userTopic++) { + for (int itemTopic = 0; itemTopic < itemTopicSize; itemTopic++) { + probability += user2TopicProbabilities.getValue(userIndex, userTopic) * userTopic2ItemTopicProbabilities.getValue(userTopic, itemTopic) * userTopic2ItemTopicScoreProbabilities[userTopic][itemTopic][entry.getValue()]; + } + } + value += score * probability; + probabilities += probability; + } + instance.setQuantityMark(value / probabilities); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BPMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BPMFModel.java new file mode 100644 index 0000000..076f49b --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BPMFModel.java @@ -0,0 +1,300 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import org.apache.commons.math3.distribution.GammaDistribution; +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.MatrixUtility; +import com.jstarcraft.ai.math.algorithm.probability.QuantityProbability; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.model.exception.ModelException; + +/** + * + * BPMF推荐器 + * + *
+ * Bayesian Probabilistic Matrix Factorization using Markov Chain Monte Carlo
+ * Reference: LibRec team
+ * 
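
BPMFModel draws user and item factor matrices by Gibbs sampling and keeps every post-warm-up sample; prediction then averages meanScore + u·v over the samples. The incremental form used by its predict() is equivalent to a running mean, as this editor's sketch shows (names invented):

```java
// Illustrative only: Monte Carlo average over posterior samples, written as the
// same running mean that BPMFModel.predict() computes sample by sample.
public final class BpmfAverageSketch {
    static float posteriorPredictive(float meanScore, float[][] userSamples, float[][] itemSamples) {
        float value = 0F;
        for (int sample = 0; sample < userSamples.length; sample++) {
            float dot = 0F;
            for (int factor = 0; factor < userSamples[sample].length; factor++) {
                dot += userSamples[sample][factor] * itemSamples[sample][factor];
            }
            value = (value * sample + meanScore + dot) / (sample + 1); // running mean
        }
        return value;
    }
}
```
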
+ * + * @author Birdy + * + */ +public class BPMFModel extends MatrixFactorizationModel { + + private float userMean, userWishart; + + private float itemMean, itemWishart; + + private float userBeta, itemBeta; + + private float rateSigma; + + private int gibbsIterations; + + private DenseMatrix[] userMatrixes; + + private DenseMatrix[] itemMatrixes; + + private QuantityProbability normalDistribution; + private QuantityProbability[] userGammaDistributions; + private QuantityProbability[] itemGammaDistributions; + + private class HyperParameter { + + // 缓存 + private float[] thisVectorCache; + private float[] thatVectorCache; + private float[] thisMatrixCache; + private float[] thatMatrixCache; + + private DenseVector factorMeans; + + private DenseMatrix factorVariances; + + private DenseVector randoms; + + private DenseVector outerMeans, innerMeans; + + private DenseMatrix covariance, cholesky, inverse, transpose, gaussian, gamma, wishart, copy; + + HyperParameter(int cache, DenseMatrix factors) { + if (cache < factorSize) { + cache = factorSize; + } + thisVectorCache = new float[cache]; + thisMatrixCache = new float[cache * factorSize]; + thatVectorCache = new float[cache]; + thatMatrixCache = new float[cache * factorSize]; + + factorMeans = DenseVector.valueOf(factorSize); + float scale = factors.getRowSize(); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + factorMeans.setValue(factorIndex, factors.getColumnVector(factorIndex).getSum(false) / scale); + } + outerMeans = DenseVector.valueOf(factors.getRowSize()); + innerMeans = DenseVector.valueOf(factors.getRowSize()); + covariance = DenseMatrix.valueOf(factorSize, factorSize); + + cholesky = DenseMatrix.valueOf(factorSize, factorSize); + + inverse = DenseMatrix.valueOf(factorSize, factorSize); + transpose = DenseMatrix.valueOf(factorSize, factorSize); + + randoms = DenseVector.valueOf(factorSize); + gaussian = DenseMatrix.valueOf(factorSize, factorSize); + gamma = DenseMatrix.valueOf(factorSize, factorSize); + wishart = DenseMatrix.valueOf(factorSize, factorSize); + + copy = DenseMatrix.valueOf(factorSize, factorSize); + + factorVariances = MatrixUtility.inverse(MatrixUtility.covariance(factors, outerMeans, innerMeans, covariance), copy, inverse); + } + + /** + * 取样 + * + * @param hyperParameter + * @param factors + * @param normalMu + * @param normalBeta + * @param wishartScale + * @return + * @throws ModelException + */ + private void sampleParameter(QuantityProbability[] gammaDistributions, DenseMatrix factors, float normalMu, float normalBeta, float wishartScale) throws ModelException { + int rowSize = factors.getRowSize(); + int columnSize = factors.getColumnSize(); + // 重复利用内存. 
+ DenseVector meanCache = DenseVector.valueOf(factorSize, thisVectorCache); + float scale = factors.getRowSize(); + meanCache.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(factors.getColumnVector(index).getSum(false) / scale); + }); + float beta = normalBeta + rowSize; + DenseMatrix populationVariance = MatrixUtility.covariance(factors, outerMeans, innerMeans, covariance); + wishart.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = 0F; + if (row == column) { + value = wishartScale; + } + value += populationVariance.getValue(row, column) * rowSize; + value += (normalMu - meanCache.getValue(row)) * (normalMu - meanCache.getValue(column)) * (normalBeta * rowSize / beta); + scalar.setValue(value); + }); + DenseMatrix wishartMatrix = wishart; + wishartMatrix = MatrixUtility.inverse(wishartMatrix, copy, inverse); + wishartMatrix.addMatrix(transpose.copyMatrix(wishartMatrix, true), false).scaleValues(0.5F); + wishartMatrix = MatrixUtility.wishart(wishartMatrix, normalDistribution, gammaDistributions, randoms, cholesky, gaussian, gamma, transpose, wishart); + if (wishartMatrix != null) { + factorVariances = wishartMatrix; + } + DenseMatrix normalVariance = MatrixUtility.cholesky(MatrixUtility.inverse(factorVariances, copy, inverse).scaleValues(normalBeta), cholesky); + if (normalVariance != null) { + randoms.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(normalDistribution.sample().floatValue()); + }); + factorMeans.dotProduct(normalVariance, false, randoms, MathCalculator.SERIAL).iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value + (normalMu * normalBeta + meanCache.getValue(index) * rowSize) * (1F / beta)); + }); + } + } + + /** + * 更新 + * + * @param factorMatrix + * @param scoreVector + * @param hyperParameter + * @return + * @throws ModelException + */ + private void updateParameter(DenseMatrix factorMatrix, SparseVector scoreVector, DenseVector factorVector) throws ModelException { + int size = scoreVector.getElementSize(); + // 重复利用内存. + DenseMatrix factorCache = DenseMatrix.valueOf(size, factorSize, thisMatrixCache); + MathVector meanCache = DenseVector.valueOf(size, thisVectorCache); + int index = 0; + for (VectorScalar term : scoreVector) { + meanCache.setValue(index, term.getValue() - meanScore); + MathVector vector = factorMatrix.getRowVector(term.getIndex()); + factorCache.getRowVector(index).iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(vector.getValue(scalar.getIndex())); + }); + index++; + } + transpose.dotProduct(factorCache, true, factorCache, false, MathCalculator.SERIAL); + transpose.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue(value * rateSigma + factorVariances.getValue(row, column)); + }); + DenseMatrix covariance = transpose; + covariance = MatrixUtility.inverse(covariance, copy, inverse); + // 重复利用内存. + meanCache = DenseVector.valueOf(factorCache.getColumnSize(), thatVectorCache).dotProduct(factorCache, true, meanCache, MathCalculator.SERIAL); + meanCache.scaleValues(rateSigma); + // 重复利用内存. 
+ meanCache.addVector(DenseVector.valueOf(factorVariances.getRowSize(), thisVectorCache).dotProduct(factorVariances, false, factorMeans, MathCalculator.SERIAL)); + // 重复利用内存. + meanCache = DenseVector.valueOf(covariance.getRowSize(), thisVectorCache).dotProduct(covariance, false, meanCache, MathCalculator.SERIAL); + covariance = MatrixUtility.cholesky(covariance, cholesky); + if (covariance != null) { + randoms.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(normalDistribution.sample().floatValue()); + }); + factorVector.dotProduct(covariance, false, randoms, MathCalculator.SERIAL).addVector(meanCache); + } else { + factorVector.setValues(0F); + } + } + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userMean = configuration.getFloat("recommender.recommender.user.mu", 0F); + userBeta = configuration.getFloat("recommender.recommender.user.beta", 1F); + userWishart = configuration.getFloat("recommender.recommender.user.wishart.scale", 1F); + + itemMean = configuration.getFloat("recommender.recommender.item.mu", 0F); + itemBeta = configuration.getFloat("recommender.recommender.item.beta", 1F); + itemWishart = configuration.getFloat("recommender.recommender.item.wishart.scale", 1F); + + rateSigma = configuration.getFloat("recommender.recommender.rating.sigma", 2F); + + gibbsIterations = configuration.getInteger("recommender.recommender.iterations.gibbs", 1); + + userMatrixes = new DenseMatrix[epocheSize - 1]; + itemMatrixes = new DenseMatrix[epocheSize - 1]; + + normalDistribution = new QuantityProbability(JDKRandomGenerator.class, factorSize, NormalDistribution.class, 0D, 1D); + userGammaDistributions = new QuantityProbability[factorSize]; + itemGammaDistributions = new QuantityProbability[factorSize]; + for (int index = 0; index < factorSize; index++) { + userGammaDistributions[index] = new QuantityProbability(JDKRandomGenerator.class, index, GammaDistribution.class, (userSize + factorSize - (index + 1D)) / 2D, 2D); + itemGammaDistributions[index] = new QuantityProbability(JDKRandomGenerator.class, index, GammaDistribution.class, (itemSize + factorSize - (index + 1D)) / 2D, 2D); + } + } + + @Override + protected void doPractice() { + int cacheSize = 0; + SparseVector[] userVectors = new SparseVector[userSize]; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + cacheSize = cacheSize < userVector.getElementSize() ? userVector.getElementSize() : cacheSize; + userVectors[userIndex] = userVector; + } + + SparseVector[] itemVectors = new SparseVector[itemSize]; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + cacheSize = cacheSize < itemVector.getElementSize() ? 
itemVector.getElementSize() : cacheSize; + itemVectors[itemIndex] = itemVector; + } + + // TODO 此处考虑重构 + HyperParameter userParameter = new HyperParameter(cacheSize, userFactors); + HyperParameter itemParameter = new HyperParameter(cacheSize, itemFactors); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + userParameter.sampleParameter(userGammaDistributions, userFactors, userMean, userBeta, userWishart); + itemParameter.sampleParameter(itemGammaDistributions, itemFactors, itemMean, itemBeta, itemWishart); + for (int gibbsIteration = 0; gibbsIteration < gibbsIterations; gibbsIteration++) { + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector scoreVector = userVectors[userIndex]; + if (scoreVector.getElementSize() == 0) { + continue; + } + userParameter.updateParameter(itemFactors, scoreVector, userFactors.getRowVector(userIndex)); + } + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector scoreVector = itemVectors[itemIndex]; + if (scoreVector.getElementSize() == 0) { + continue; + } + itemParameter.updateParameter(userFactors, scoreVector, itemFactors.getRowVector(itemIndex)); + } + } + + if (epocheIndex > 0) { + userMatrixes[epocheIndex - 1] = DenseMatrix.copyOf(userFactors); + itemMatrixes[epocheIndex - 1] = DenseMatrix.copyOf(itemFactors); + } + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + float value = 0F; + for (int iterationStep = 0; iterationStep < epocheSize - 1; iterationStep++) { + DenseVector userVector = userMatrixes[iterationStep].getRowVector(userIndex); + DenseVector itemVector = itemMatrixes[iterationStep].getRowVector(itemIndex); + value = (value * (iterationStep) + meanScore + scalar.dotProduct(userVector, itemVector).getValue()) / (iterationStep + 1); + } + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BUCMRatingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BUCMRatingModel.java new file mode 100644 index 0000000..09a3211 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BUCMRatingModel.java @@ -0,0 +1,40 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Map.Entry; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.rns.model.collaborative.BUCMModel; + +/** + * + * BUCM推荐器 + * + *
+ * Bayesian User Community Model
+ * Modeling Item Selection and Relevance for Accurate Recommendations
+ * Reference: LibRec team
+ * 
+ * + * @author Birdy + * + */ +public class BUCMRatingModel extends BUCMModel { + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = 0F, probabilities = 0F; + for (Entry term : scoreIndexes.entrySet()) { + float score = term.getKey(); + float probability = 0F; + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + probability += userTopicProbabilities.getValue(userIndex, topicIndex) * topicItemProbabilities.getValue(topicIndex, itemIndex) * topicItemScoreProbabilities[topicIndex][itemIndex][term.getValue()]; + } + value += probability * score; + probabilities += probability; + } + instance.setQuantityMark(value / probabilities); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BiasedMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BiasedMFModel.java new file mode 100644 index 0000000..ea04f05 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/BiasedMFModel.java @@ -0,0 +1,116 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * BiasedMF推荐器 + * + *
+ * Biased Matrix Factorization
+ * Reference: LibRec team
+ * 
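
Each observed rating triggers one stochastic gradient step on the two biases and the two factor vectors. A plain-array sketch of that step, in the same update order as the doPractice() loop below (editor's illustration, hypothetical names):

```java
// Illustrative only: one biased-MF SGD step for a single observed rating;
// biases = {userBias, itemBias}, error = score - prediction.
public final class BiasedMfStepSketch {
    static void step(float error, float learnRate, float regBias, float regUser, float regItem,
            float[] userFactors, float[] itemFactors, float[] biases) {
        biases[0] += learnRate * (error - regBias * biases[0]);
        biases[1] += learnRate * (error - regBias * biases[1]);
        for (int factor = 0; factor < userFactors.length; factor++) {
            float userFactor = userFactors[factor], itemFactor = itemFactors[factor];
            userFactors[factor] += learnRate * (error * itemFactor - regUser * userFactor);
            itemFactors[factor] += learnRate * (error * userFactor - regItem * itemFactor);
        }
    }
}
```

Note that both factor updates read the pre-update values of userFactor and itemFactor, exactly as the class below does.
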
+ * + * @author Birdy + * + */ +public class BiasedMFModel extends MatrixFactorizationModel { + /** + * bias regularization + */ + protected float regBias; + + /** + * user biases + */ + protected DenseVector userBiases; + + /** + * user biases + */ + protected DenseVector itemBiases; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + regBias = configuration.getFloat("recommender.bias.regularization", 0.01F); + + // initialize the userBiased and itemBiased + userBiases = DenseVector.valueOf(userSize); + userBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + } + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // user userIdx + int itemIndex = term.getColumn(); // item itemIdx + float score = term.getValue(); // real rating on item + // itemIdx rated by user + // userIdx + float predict = predict(userIndex, itemIndex); + float error = score - predict; + totalError += error * error; + + // update user and item bias + float userBias = userBiases.getValue(userIndex); + userBiases.shiftValue(userIndex, learnRatio * (error - regBias * userBias)); + totalError += regBias * userBias * userBias; + float itemBias = itemBiases.getValue(itemIndex); + itemBiases.shiftValue(itemIndex, learnRatio * (error - regBias * itemBias)); + totalError += regBias * itemBias * itemBias; + + // update user and item factors + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float itemFactor = itemFactors.getValue(itemIndex, factorIndex); + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (error * itemFactor - userRegularization * userFactor)); + itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * (error * userFactor - itemRegularization * itemFactor)); + totalError += userRegularization * userFactor * userFactor + itemRegularization * itemFactor * itemFactor; + } + } + + totalError *= 0.5D; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector userVector = userFactors.getRowVector(userIndex); + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + float value = scalar.dotProduct(userVector, itemVector).getValue(); + value += meanScore + userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex); + return value; + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/CCDModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/CCDModel.java new file mode 100644 index 0000000..3441613 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/CCDModel.java @@ -0,0 +1,102 @@ +package 
com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Date; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.StringUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * CCD推荐器 + * + *
+ * Large-Scale Parallel Collaborative Filtering for the Netflix Prize
+ * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/
+ * Reference: LibRec team
+ * 
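
Cyclic coordinate descent updates one factor coordinate at a time in closed form, keeping the score entries as running residuals: the loops below write the adjusted residual back through term.setValue after every coordinate. An editor's sketch of a single user-side coordinate update (invented names, plain arrays):

```java
// Illustrative only: closed-form CCD update of one user factor, followed by
// folding the change back into the residuals to keep them consistent.
public final class CcdCoordinateSketch {
    static float update(float[] residuals, float[] itemColumn, float oldFactor, float regularization) {
        float numerator = 0F, denominator = 0F;
        for (int index = 0; index < residuals.length; index++) {
            numerator += (residuals[index] + oldFactor * itemColumn[index]) * itemColumn[index];
            denominator += itemColumn[index] * itemColumn[index];
        }
        float newFactor = numerator / (denominator + regularization);
        for (int index = 0; index < residuals.length; index++) {
            residuals[index] -= (newFactor - oldFactor) * itemColumn[index];
        }
        return newFactor;
    }
}
```
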
+ * + * @author Birdy + * + */ +public class CCDModel extends MatrixFactorizationModel { + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + } + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = 0F; + float numerator = 0F; + float denominator = 0F; + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + numerator += (term.getValue() + userFactors.getValue(userIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex)) * itemFactors.getValue(itemIndex, factorIndex); + denominator += itemFactors.getValue(itemIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex); + } + userFactor = numerator / (denominator + userRegularization); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + term.setValue(term.getValue() - (userFactor - userFactors.getValue(userIndex, factorIndex)) * itemFactors.getValue(itemIndex, factorIndex)); + } + userFactors.setValue(userIndex, factorIndex, userFactor); + } + } + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float itemFactor = 0F; + float numerator = 0F; + float denominator = 0F; + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + numerator += (term.getValue() + userFactors.getValue(userIndex, factorIndex) * itemFactors.getValue(itemIndex, factorIndex)) * userFactors.getValue(userIndex, factorIndex); + denominator += userFactors.getValue(userIndex, factorIndex) * userFactors.getValue(userIndex, factorIndex); + } + itemFactor = numerator / (denominator + itemRegularization); + for (VectorScalar term : itemVector) { + int userIndex = term.getIndex(); + term.setValue(term.getValue() - (itemFactor - itemFactors.getValue(itemIndex, factorIndex)) * userFactors.getValue(userIndex, factorIndex)); + } + itemFactors.setValue(itemIndex, factorIndex, itemFactor); + } + } + logger.info(StringUtility.format("{} runs at iter {}/{} {}", this.getClass().getSimpleName(), epocheIndex, epocheSize, new Date())); + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + float score = scalar.dotProduct(userFactors.getRowVector(userIndex), itemFactors.getRowVector(itemIndex)).getValue(); + if (score == 0F) { + score = meanScore; + } + instance.setQuantityMark(score); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FFMModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FFMModel.java new file mode 100644 index 0000000..120c886 --- /dev/null +++ 
b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FFMModel.java @@ -0,0 +1,150 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.FactorizationMachineModel; + +/** + * + * FFM推荐器 + * + *
+ * Field Aware Factorization Machines for CTR Prediction
+ * Reference: LibRec team
+ * 
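
In field-aware factorization machines, every feature keeps one latent vector per field, and a feature pair (j1, j2) interacts through the dot product of j1's vector for j2's field with j2's vector for j1's field. The class below flattens those per-field vectors into a single featureSize × (factorSize · fieldCount) matrix addressed through featureOrders; for comparison, here is the textbook form with an explicit three-dimensional layout (editor's sketch, invented names, summed once per unordered pair):

```java
// Illustrative only: textbook FFM pairwise term; v has shape [feature][field][factor].
public final class FfmInteractionSketch {
    static float interactions(int[] features, float[] values, int[] fieldOf, float[][][] v) {
        float total = 0F;
        for (int first = 0; first < features.length; first++) {
            for (int second = first + 1; second < features.length; second++) {
                int j1 = features[first], j2 = features[second];
                float dot = 0F;
                for (int factor = 0; factor < v[j1][fieldOf[j2]].length; factor++) {
                    dot += v[j1][fieldOf[j2]][factor] * v[j2][fieldOf[j1]][factor];
                }
                total += dot * values[first] * values[second];
            }
        }
        return total;
    }
}
```
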
+ * + * @author Birdy + * + */ +public class FFMModel extends FactorizationMachineModel { + + /** + * record the + */ + private int[] featureOrders; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + // Matrix for p * (factor * filed) + // TODO 此处应该还是稀疏 + featureFactors = DenseMatrix.valueOf(featureSize, factorSize * marker.getQualityOrder()); + featureFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + // init the map for feature of filed + featureOrders = new int[featureSize]; + int count = 0; + for (int orderIndex = 0, orderSize = dimensionSizes.length; orderIndex < orderSize; orderIndex++) { + int size = dimensionSizes[orderIndex]; + for (int index = 0; index < size; index++) { + featureOrders[count + index] = orderIndex; + } + count += size; + } + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + int outerIndex = 0; + int innerIndex = 0; + float outerValue = 0F; + float innerValue = 0F; + float oldWeight = 0F; + float newWeight = 0F; + float oldFactor = 0F; + float newFactor = 0F; + for (DataInstance sample : marker) { + // TODO 因为每次的data都是1,可以考虑避免重复构建featureVector. + MathVector featureVector = getFeatureVector(sample); + float score = sample.getQuantityMark(); + float predict = predict(scalar, featureVector); + float error = predict - score; + totalError += error * error; + + // global bias + totalError += biasRegularization * globalBias * globalBias; + + // update w0 + float hW0 = 1; + float gradW0 = error * hW0 + biasRegularization * globalBias; + globalBias += -learnRatio * gradW0; + + // 1-way interactions + for (VectorScalar outerTerm : featureVector) { + outerIndex = outerTerm.getIndex(); + innerIndex = 0; + oldWeight = weightVector.getValue(outerIndex); + newWeight = outerTerm.getValue(); + newWeight = error * newWeight + weightRegularization * oldWeight; + weightVector.shiftValue(outerIndex, -learnRatio * newWeight); + totalError += weightRegularization * oldWeight * oldWeight; + outerValue = outerTerm.getValue(); + innerValue = 0F; + // 2-way interactions + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + oldFactor = featureFactors.getValue(outerIndex, featureOrders[outerIndex] + factorIndex); + newFactor = 0F; + for (VectorScalar innerTerm : featureVector) { + innerIndex = innerTerm.getIndex(); + innerValue = innerTerm.getValue(); + if (innerIndex != outerIndex) { + newFactor += outerValue * featureFactors.getValue(innerIndex, featureOrders[outerIndex] + factorIndex) * innerValue; + } + } + newFactor = error * newFactor + factorRegularization * oldFactor; + featureFactors.shiftValue(outerIndex, featureOrders[outerIndex] + factorIndex, -learnRatio * newFactor); + totalError += factorRegularization * oldFactor * oldFactor; + } + } + } + + totalError *= 0.5; + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + } + + @Override + protected float predict(DefaultScalar scalar, MathVector featureVector) { + float value = 0F; + // global bias + value += globalBias; + // 1-way interaction + value += scalar.dotProduct(weightVector, featureVector).getValue(); + int outerIndex = 0; + int innerIndex = 0; + float outerValue = 0F; + float innerValue = 0F; + // 2-way interaction + for (int featureIndex = 0; featureIndex 
+            for (VectorScalar outerVector : featureVector) {
+                outerIndex = outerVector.getIndex();
+                outerValue = outerVector.getValue();
+                for (VectorScalar innerVector : featureVector) {
+                    innerIndex = innerVector.getIndex();
+                    innerValue = innerVector.getValue();
+                    if (outerIndex != innerIndex) {
+                        value += featureFactors.getValue(outerIndex, featureOrders[innerIndex] + factorIndex) * featureFactors.getValue(innerIndex, featureOrders[outerIndex] + factorIndex) * outerValue * innerValue;
+                    }
+                }
+            }
+        }
+        return value;
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FMALSModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FMALSModel.java
new file mode 100644
index 0000000..295f05e
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FMALSModel.java
@@ -0,0 +1,188 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.HashMatrix;
+import com.jstarcraft.ai.math.structure.matrix.SparseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.FactorizationMachineModel;
+
+import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap;
+
+/**
+ *
+ * FM ALS recommender
+ *
+ * <pre>
+ * Factorization Machines via Alternating Least Squares<br>
+ * Based on the LibRec team's implementation<br>
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class FMALSModel extends FactorizationMachineModel {
+
+    /**
+     * training appender matrix
+     */
+    private SparseMatrix featureMatrix;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        // init Q
+        // TODO this is effectively rateFactors
+        actionFactors = DenseMatrix.valueOf(actionSize, factorSize);
+
+        // construct training appender matrix
+        HashMatrix table = new HashMatrix(true, actionSize, featureSize, new Long2FloatRBTreeMap());
+        int index = 0;
+        int order = marker.getQualityOrder();
+        for (DataInstance sample : model) {
+            int count = 0;
+            for (int orderIndex = 0; orderIndex < order; orderIndex++) {
+                table.setValue(index, count + sample.getQualityFeature(orderIndex), 1F);
+                count += dimensionSizes[orderIndex];
+            }
+            index++;
+        }
+        // TODO consider refactoring (this appears to duplicate FactorizationMachineRecommender.getFeatureVector)
+        featureMatrix = SparseMatrix.valueOf(actionSize, featureSize, table);
+    }
+
+    @Override
+    protected void doPractice() {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        // precompute Q and the errors, for efficiency
+        DenseVector errorVector = DenseVector.valueOf(actionSize);
+        int index = 0;
+        for (DataInstance sample : marker) {
+            // TODO since every feature value is 1, consider caching featureVector instead of rebuilding it per sample.
+            MathVector featureVector = getFeatureVector(sample);
+            float score = sample.getQuantityMark();
+            float predict = predict(scalar, featureVector);
+
+            float error = score - predict;
+            errorVector.setValue(index, error);
+
+            for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                float sum = 0F;
+                for (VectorScalar vectorTerm : featureVector) {
+                    sum += featureFactors.getValue(vectorTerm.getIndex(), factorIndex) * vectorTerm.getValue();
+                }
+                actionFactors.setValue(index, factorIndex, sum);
+            }
+            index++;
+        }
+
+        /**
+         * Parameters are optimized with the formula in [1]; errors are updated with:
+         * error_new = error_old + theta_old * h_old - theta_new * h_new.
+         * Reference: [1] Rendle, Steffen, "Factorization Machines with libFM."
+         * ACM Transactions on Intelligent Systems and Technology, 2012.
+         */
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            // global bias
+            float numerator = 0F;
+            float denominator = 0F;
+
+            for (int scoreIndex = 0; scoreIndex < actionSize; scoreIndex++) {
+                // TODO since this iterates over the featureVector of trainTensor, h_theta is the constant 1.
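+                // The bias, weight and factor updates in this epoch loop are all instances
+                // of the closed-form ALS step from [1]: writing predict = theta * h + rest,
+                // the minimizer of the regularized squared error is
+                //   theta_new = sum_n (theta_old * h_n^2 + h_n * error_n) / (sum_n h_n^2 + lambda),
+                // with h_n = 1 for the bias, h_n = x_n for weights, and the h_theta
+                // expression used in the 2-way section for factors.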
+                float h_theta = 1F;
+                numerator += globalBias * h_theta * h_theta + h_theta * errorVector.getValue(scoreIndex);
+                denominator += h_theta;
+            }
+            denominator += biasRegularization;
+            float bias = numerator / denominator;
+            // update errors
+            for (int scoreIndex = 0; scoreIndex < actionSize; scoreIndex++) {
+                float oldError = errorVector.getValue(scoreIndex);
+                float newError = oldError + (globalBias - bias);
+                errorVector.setValue(scoreIndex, newError);
+                totalError += oldError * oldError;
+            }
+
+            // update w0
+            globalBias = bias;
+            totalError += biasRegularization * globalBias * globalBias;
+
+            // 1-way interactions
+            for (int featureIndex = 0; featureIndex < featureSize; featureIndex++) {
+                float oldWeight = weightVector.getValue(featureIndex);
+                numerator = 0F;
+                denominator = 0F;
+                // TODO consider refactoring
+                SparseVector featureVector = featureMatrix.getColumnVector(featureIndex);
+                for (VectorScalar vectorTerm : featureVector) {
+                    int scoreIndex = vectorTerm.getIndex();
+                    float h_theta = vectorTerm.getValue();
+                    numerator += oldWeight * h_theta * h_theta + h_theta * errorVector.getValue(scoreIndex);
+                    denominator += h_theta * h_theta;
+                }
+                denominator += weightRegularization;
+                float newWeight = numerator / denominator;
+                // update errors
+                for (VectorScalar vectorTerm : featureVector) {
+                    int scoreIndex = vectorTerm.getIndex();
+                    float oldError = errorVector.getValue(scoreIndex);
+                    float newError = oldError + (oldWeight - newWeight) * vectorTerm.getValue();
+                    errorVector.setValue(scoreIndex, newError);
+                }
+                // update W
+                weightVector.setValue(featureIndex, newWeight);
+                totalError += weightRegularization * oldWeight * oldWeight;
+            }
+
+            // 2-way interactions
+            for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                for (int featureIndex = 0; featureIndex < featureSize; featureIndex++) {
+                    float oldValue = featureFactors.getValue(featureIndex, factorIndex);
+                    numerator = 0F;
+                    denominator = 0F;
+                    SparseVector featureVector = featureMatrix.getColumnVector(featureIndex);
+                    for (VectorScalar vectorTerm : featureVector) {
+                        int scoreIndex = vectorTerm.getIndex();
+                        float x_val = vectorTerm.getValue();
+                        float h_theta = x_val * (actionFactors.getValue(scoreIndex, factorIndex) - oldValue * x_val);
+                        numerator += oldValue * h_theta * h_theta + h_theta * errorVector.getValue(scoreIndex);
+                        denominator += h_theta * h_theta;
+                    }
+                    denominator += factorRegularization;
+                    float newValue = numerator / denominator;
+                    // update errors and Q
+                    for (VectorScalar vectorTerm : featureVector) {
+                        int scoreIndex = vectorTerm.getIndex();
+                        float x_val = vectorTerm.getValue();
+                        float oldScore = actionFactors.getValue(scoreIndex, factorIndex);
+                        float newScore = oldScore + (newValue - oldValue) * x_val;
+                        float h_theta_old = x_val * (oldScore - oldValue * x_val);
+                        float h_theta_new = x_val * (newScore - newValue * x_val);
+                        float oldError = errorVector.getValue(scoreIndex);
+                        float newError = oldError + oldValue * h_theta_old - newValue * h_theta_new;
+                        errorVector.setValue(scoreIndex, newError);
+                        actionFactors.setValue(scoreIndex, factorIndex, newScore);
+                    }
+
+                    // update V
+                    featureFactors.setValue(featureIndex, factorIndex, newValue);
+                    totalError += factorRegularization * oldValue * oldValue;
+                }
+            }
+
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            currentError = totalError;
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FMSGDModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FMSGDModel.java
new file mode 100644
index 0000000..24886ae
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/FMSGDModel.java
@@ -0,0 +1,88 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.FactorizationMachineModel;
+
+/**
+ *
+ * FM SGD recommender
+ *
+ * <pre>
+ * Factorization Machines via Stochastic Gradient Descent with Squared Loss<br>
+ * Based on the LibRec team's implementation<br>
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class FMSGDModel extends FactorizationMachineModel {
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+    }
+
+    @Override
+    protected void doPractice() {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            for (DataInstance sample : marker) {
+                // TODO since every feature value is 1, consider caching featureVector instead of rebuilding it per sample.
+                MathVector featureVector = getFeatureVector(sample);
+                float score = sample.getQuantityMark();
+                float predict = predict(scalar, featureVector);
+
+                float error = predict - score;
+                totalError += error * error;
+
+                // global bias
+                totalError += biasRegularization * globalBias * globalBias;
+
+                // TODO since this iterates over the featureVector of trainTensor, hW0 is the constant 1.
+                float hW0 = 1F;
+                float bias = error * hW0 + biasRegularization * globalBias;
+
+                // update w0
+                globalBias += -learnRatio * bias;
+
+                // 1-way interactions
+                for (VectorScalar outerTerm : featureVector) {
+                    int outerIndex = outerTerm.getIndex();
+                    float oldWeight = weightVector.getValue(outerIndex);
+                    float featureWeight = outerTerm.getValue();
+                    float newWeight = error * featureWeight + weightRegularization * oldWeight;
+                    weightVector.shiftValue(outerIndex, -learnRatio * newWeight);
+                    totalError += weightRegularization * oldWeight * oldWeight;
+                    // 2-way interactions
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float oldValue = featureFactors.getValue(outerIndex, factorIndex);
+                        float newValue = 0F;
+                        for (VectorScalar innerTerm : featureVector) {
+                            int innerIndex = innerTerm.getIndex();
+                            if (innerIndex != outerIndex) {
+                                newValue += featureWeight * featureFactors.getValue(innerIndex, factorIndex) * innerTerm.getValue();
+                            }
+                        }
+                        newValue = error * newValue + factorRegularization * oldValue;
+                        featureFactors.shiftValue(outerIndex, factorIndex, -learnRatio * newValue);
+                        totalError += factorRegularization * oldValue * oldValue;
+                    }
+                }
+            }
+
+            totalError *= 0.5F;
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            currentError = totalError;
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/GPLSAModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/GPLSAModel.java
new file mode 100644
index 0000000..b4e0e3e
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/GPLSAModel.java
@@ -0,0 +1,224 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.MathUtility;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.table.SparseTable;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.Float2FloatKeyValue;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+import com.jstarcraft.rns.utility.GaussianUtility;
+
+import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
+
+/**
+ *
+ * GPLSA recommender
+ *
+ * <pre>
+ * Collaborative Filtering via Gaussian Probabilistic Latent Semantic Analysis
+ * Based on the LibRec team's implementation<br>
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class GPLSAModel extends ProbabilisticGraphicalModel {
+
+    /*
+     * {user, item, {topic z, probability}}
+     */
+    protected SparseTable<float[]> probabilityTensor;
+    /*
+     * Conditional Probability: P(z|u)
+     */
+    protected DenseMatrix userTopicProbabilities;
+    /*
+     * Conditional Probability: P(v|y,z)
+     */
+    protected DenseMatrix itemMus, itemSigmas;
+    /*
+     * regularize ratings
+     */
+    protected DenseVector userMus, userSigmas;
+    /*
+     * smoothing weight
+     */
+    protected float smoothWeight;
+    /*
+     * tempered EM parameter beta, suggested by Wu Bin
+     */
+    protected float beta;
+    /*
+     * small value for initialization
+     */
+    protected static float smallValue = MathUtility.EPSILON;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        // Initialize users' conditional probabilities
+        userTopicProbabilities = DenseMatrix.valueOf(userSize, factorSize);
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            DenseVector probabilityVector = userTopicProbabilities.getRowVector(userIndex);
+            probabilityVector.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                scalar.setValue(RandomUtility.randomInteger(factorSize) + 1);
+            });
+            probabilityVector.scaleValues(1F / probabilityVector.getSum(false));
+        }
+
+        Float2FloatKeyValue keyValue = scoreMatrix.getVariance();
+        float mean = keyValue.getKey();
+        float variance = keyValue.getValue() / scoreMatrix.getElementSize();
+
+        userMus = DenseVector.valueOf(userSize);
+        userSigmas = DenseVector.valueOf(userSize);
+        smoothWeight = configuration.getInteger("recommender.recommender.smoothWeight");
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+            int size = userVector.getElementSize();
+            if (size < 1) {
+                continue;
+            }
+            float mu = (userVector.getSum(false) + smoothWeight * mean) / (size + smoothWeight);
+            userMus.setValue(userIndex, mu);
+            float sigma = userVector.getVariance(mu);
+            sigma += smoothWeight * variance;
+            sigma = (float) Math.sqrt(sigma / (size + smoothWeight));
+            userSigmas.setValue(userIndex, sigma);
+        }
+
+        // Initialize Q
+        // TODO refactor
+        probabilityTensor = new SparseTable<>(true, userSize, itemSize, new Int2ObjectRBTreeMap<>());
+
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            score = (score - userMus.getValue(userIndex)) / userSigmas.getValue(userIndex);
+            term.setValue(score);
+            probabilityTensor.setValue(userIndex, itemIndex, new float[factorSize]);
+        }
+
+        itemMus = DenseMatrix.valueOf(itemSize, factorSize);
+        itemSigmas = DenseMatrix.valueOf(itemSize, factorSize);
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex);
+            int size = itemVector.getElementSize();
+            if (size < 1) {
+                continue;
+            }
+            float mu = itemVector.getSum(false) / size;
+            float sigma = itemVector.getVariance(mu);
+            sigma = (float) Math.sqrt(sigma / size);
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                itemMus.setValue(itemIndex, topicIndex, mu + smallValue * RandomUtility.randomFloat(1F));
+                itemSigmas.setValue(itemIndex, topicIndex, sigma + smallValue * RandomUtility.randomFloat(1F));
+            }
+        }
+    }
+
+    @Override
+    protected void eStep() {
+        // variational inference to compute Q
+        float[] numerators = new float[factorSize];
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            float denominator = 0F;
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                float pdf = GaussianUtility.probabilityDensity(score, itemMus.getValue(itemIndex, topicIndex), itemSigmas.getValue(itemIndex, topicIndex));
+                float value = (float) Math.pow(userTopicProbabilities.getValue(userIndex, topicIndex) * pdf, beta); // Tempered EM
+                numerators[topicIndex] = value;
+                denominator += value;
+            }
+            float[] probabilities = probabilityTensor.getValue(userIndex, itemIndex);
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                float probability = (denominator > 0 ? numerators[topicIndex] / denominator : 0);
+                probabilities[topicIndex] = probability;
+            }
+        }
+    }
+
+    @Override
+    protected void mStep() {
+        float[] numerators = new float[factorSize];
+        // theta_u,z
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+            if (userVector.getElementSize() < 1) {
+                continue;
+            }
+            float denominator = 0F;
+            // reset the per-user accumulators; the original assigned instead of
+            // accumulating, which normalized the last item's posterior by the sum
+            // over all of the user's items
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                numerators[topicIndex] = 0F;
+            }
+            for (VectorScalar term : userVector) {
+                int itemIndex = term.getIndex();
+                float[] probabilities = probabilityTensor.getValue(userIndex, itemIndex);
+                for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                    numerators[topicIndex] += probabilities[topicIndex];
+                    denominator += probabilities[topicIndex];
+                }
+            }
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                userTopicProbabilities.setValue(userIndex, topicIndex, numerators[topicIndex] / denominator);
+            }
+        }
+
+        // topicItemMu, topicItemSigma
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex);
+            if (itemVector.getElementSize() < 1) {
+                continue;
+            }
+            float numerator = 0F, denominator = 0F;
+            for (VectorScalar term : itemVector) {
+                int userIndex = term.getIndex();
+                float score = term.getValue();
+                float[] probabilities = probabilityTensor.getValue(userIndex, itemIndex);
+                for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                    float probability = probabilities[topicIndex];
+                    numerator += score * probability;
+                    denominator += probability;
+                }
+            }
+            float mu = denominator > 0F ? numerator / denominator : 0F;
+            numerator = 0F;
+            for (VectorScalar term : itemVector) {
+                int userIndex = term.getIndex();
+                float score = term.getValue();
+                float[] probabilities = probabilityTensor.getValue(userIndex, itemIndex);
+                for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                    double probability = probabilities[topicIndex];
+                    numerator += Math.pow(score - mu, 2) * probability;
+                }
+            }
+            float sigma = (float) (denominator > 0F ? Math.sqrt(numerator / denominator) : 0F);
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                itemMus.setValue(itemIndex, topicIndex, mu);
+                itemSigmas.setValue(itemIndex, topicIndex, sigma);
+            }
+        }
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        float sum = 0F;
+        for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+            sum += userTopicProbabilities.getValue(userIndex, topicIndex) * itemMus.getValue(itemIndex, topicIndex);
+        }
+        instance.setQuantityMark(userMus.getValue(userIndex) + userSigmas.getValue(userIndex) * sum);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/IRRGModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/IRRGModel.java
new file mode 100644
index 0000000..a04bad0
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/IRRGModel.java
@@ -0,0 +1,429 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import com.google.common.collect.HashBasedTable;
+import com.google.common.collect.Table;
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.HashMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.matrix.SparseMatrix;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.KeyValue;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+import com.jstarcraft.rns.utility.LogisticUtility;
+
+import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap;
+
+/**
+ *
+ * IRRG recommender
+ *
+ * <pre>
+ * Exploiting Implicit Item Relationships for Recommender Systems
+ * Based on the LibRec team's implementation<br>
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class IRRGModel extends MatrixFactorizationModel {
+
+    /** item relationship regularization coefficient */
+    private float correlationRegularization;
+
+    /** adjusts the reliability */
+    // TODO make this configurable.
+    private float reliability = 50F;
+
+    /** k nearest neighborhoods */
+    // TODO make this configurable.
+    private int neighborSize = 50;
+
+    /** stores the co-occurrence between two items. */
+    @Deprecated
+    private Table<Integer, Integer, Integer> itemCount = HashBasedTable.create();
+
+    /** stores item-to-item AR */
+    @Deprecated
+    private Table<Integer, Integer, Float> itemCorrsAR = HashBasedTable.create();
+
+    /** stores sorted item-to-item AR */
+    @Deprecated
+    private Table<Integer, Integer, Float> itemCorrsAR_Sorted = HashBasedTable.create();
+
+    /** stores the complementary item-to-item AR */
+    @Deprecated
+    private Table<Integer, Integer, Float> itemCorrsAR_added = HashBasedTable.create();
+
+    /** stores group-to-item AR */
+    @Deprecated
+    private Map<Integer, List<KeyValue<KeyValue<Integer, Integer>, Float>>> itemCorrsGAR = new HashMap<>();
+
+    private SparseMatrix complementMatrix;
+
+    /** stores sorted group-to-item AR */
+    private Map<Integer, SparseMatrix> itemCorrsGAR_Sorted = new HashMap<>();
+
+    // TODO temporary table used in place of trainMatrix.getTermValue.
+    @Deprecated
+    Table<Integer, Integer, Float> dataTable = HashBasedTable.create();
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            dataTable.put(userIndex, itemIndex, term.getValue());
+        }
+
+        correlationRegularization = configuration.getFloat("recommender.alpha");
+        userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(RandomUtility.randomFloat(0.8F));
+        });
+        itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+            scalar.setValue(RandomUtility.randomFloat(0.8F));
+        });
+
+        computeAssociationRuleByItem();
+        sortAssociationRuleByItem();
+        computeAssociationRuleByGroup();
+        sortAssociationRuleByGroup();
+        complementAssociationRule();
+        complementMatrix = SparseMatrix.valueOf(itemSize, itemSize, itemCorrsAR_added);
+    }
+
+    @Override
+    protected void doPractice() {
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+
+            DenseMatrix userDeltas = DenseMatrix.valueOf(userSize, factorSize);
+            DenseMatrix itemDeltas = DenseMatrix.valueOf(itemSize, factorSize);
+            for (MatrixScalar term : scoreMatrix) {
+                int userIndex = term.getRow();
+                int itemIndex = term.getColumn();
+                float score = term.getValue();
+                if (score <= 0F) {
+                    continue;
+                }
+                float predict = super.predict(userIndex, itemIndex);
+                float error = LogisticUtility.getValue(predict) - (score - minimumScore) / (maximumScore - minimumScore);
+                float csgd = LogisticUtility.getGradient(predict) * error;
+
+                totalError += error * error;
+                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                    float userFactor = userFactors.getValue(userIndex, factorIndex);
+                    float itemFactor = itemFactors.getValue(itemIndex, factorIndex);
+                    userDeltas.shiftValue(userIndex, factorIndex, csgd * itemFactor + userRegularization * userFactor);
+                    itemDeltas.shiftValue(itemIndex, factorIndex, csgd * userFactor + itemRegularization * itemFactor);
+                    totalError += userRegularization * userFactor * userFactor + itemRegularization * itemFactor * itemFactor;
+                }
+            }
+
+            // complementary item-to-item AR
+            for (int leftItemIndex = 0; leftItemIndex < itemSize; leftItemIndex++) {
+                SparseVector itemVector = complementMatrix.getColumnVector(leftItemIndex);
+                for (VectorScalar term : itemVector) {
+                    int rightItemIndex = term.getIndex();
+                    float skj = term.getValue();
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float ekj = itemFactors.getValue(leftItemIndex, factorIndex) - itemFactors.getValue(rightItemIndex, factorIndex);
+                        itemDeltas.shiftValue(leftItemIndex, factorIndex, correlationRegularization * skj * ekj);
+                        totalError += correlationRegularization * skj * ekj * ekj;
+                    }
+                }
+                itemVector = complementMatrix.getRowVector(leftItemIndex);
+                for (VectorScalar term : itemVector) {
+                    int rightItemIndex = term.getIndex();
+                    float sjg = term.getValue();
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float ejg = itemFactors.getValue(leftItemIndex, factorIndex) - itemFactors.getValue(rightItemIndex, factorIndex);
+                        itemDeltas.shiftValue(leftItemIndex, factorIndex, correlationRegularization * sjg * ejg);
+                    }
+                }
+            }
+
+            // group-to-item AR
+            for (Entry<Integer, SparseMatrix> leftKeyValue : itemCorrsGAR_Sorted.entrySet()) {
+                int leftItemIndex = leftKeyValue.getKey();
+                SparseMatrix leftTable = leftKeyValue.getValue();
+                for (MatrixScalar term : leftTable) {
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float egkj = (float) (itemFactors.getValue(leftItemIndex, factorIndex) - (itemFactors.getValue(term.getRow(), factorIndex) + itemFactors.getValue(term.getColumn(), factorIndex)) / Math.sqrt(2F));
+                        float egkj_1 = correlationRegularization * term.getValue() * egkj;
+                        itemDeltas.shiftValue(leftItemIndex, factorIndex, egkj_1);
+                        totalError += egkj_1 * egkj;
+                    }
+                }
+                for (Entry<Integer, SparseMatrix> rightKeyValue : itemCorrsGAR_Sorted.entrySet()) {
+                    int rightItemIndex = rightKeyValue.getKey();
+                    if (rightItemIndex != leftItemIndex) {
+                        SparseMatrix rightTable = rightKeyValue.getValue();
+                        SparseVector itemVector = rightTable.getRowVector(leftItemIndex);
+                        for (VectorScalar term : itemVector) {
+                            for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                                float ejgk = (float) (itemFactors.getValue(rightItemIndex, factorIndex) - (itemFactors.getValue(leftItemIndex, factorIndex) + itemFactors.getValue(term.getIndex(), factorIndex)) / Math.sqrt(2F));
+                                float ejgk_1 = (float) (-correlationRegularization * term.getValue() * ejgk / Math.sqrt(2F));
+                                itemDeltas.shiftValue(leftItemIndex, factorIndex, ejgk_1);
+                            }
+                        }
+                    }
+                }
+            }
+
+            userFactors.addMatrix(userDeltas.scaleValues(-learnRatio), false);
+            itemFactors.addMatrix(itemDeltas.scaleValues(-learnRatio), false);
+
+            totalError *= 0.5F;
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        float score = super.predict(userIndex, itemIndex);
+        score = LogisticUtility.getValue(score);
+        score = minimumScore + score * (maximumScore - minimumScore);
+        instance.setQuantityMark(score);
+    }
+
+    /**
+     * Compute the association rules between items.
+     */
+    private void computeAssociationRuleByItem() {
+        // TODO the similarity computation in Abstract.getScoreList could serve as a reference here.
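+        // The rule weight computed below is a shrunk confidence: for items i -> j,
+        //   coefficient = (count / (count + reliability)) * (count / total),
+        // where count is the number of users who rated both i and j and total is the
+        // number of users who rated i; the shrink factor damps rules that are
+        // supported by only a few co-ratings.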
+        for (int leftItemIndex = 0; leftItemIndex < itemSize; leftItemIndex++) {
+            if (scoreMatrix.getColumnScope(leftItemIndex) == 0) {
+                continue;
+            }
+            SparseVector itemVector = scoreMatrix.getColumnVector(leftItemIndex);
+            int total = itemVector.getElementSize();
+            for (int rightItemIndex = 0; rightItemIndex < itemSize; rightItemIndex++) {
+                if (leftItemIndex == rightItemIndex) {
+                    continue;
+                }
+                float coefficient = 0F;
+                int count = 0;
+                for (VectorScalar term : itemVector) {
+                    int userIndex = term.getIndex();
+                    if (dataTable.contains(userIndex, rightItemIndex)) {
+                        count++;
+                    }
+                }
+                float shrink = count / (count + reliability);
+                coefficient = shrink * count / total;
+                if (coefficient > 0F) {
+                    itemCorrsAR.put(leftItemIndex, rightItemIndex, coefficient);
+                    itemCount.put(leftItemIndex, rightItemIndex, count);
+                }
+            }
+        }
+    }
+
+    /**
+     * Sort the association rules.
+     */
+    private void sortAssociationRuleByItem() {
+        for (int leftItemIndex : itemCorrsAR.columnKeySet()) {
+            int size = itemCorrsAR.column(leftItemIndex).size();
+            float[][] temp = new float[size][3];
+            int flag = 0;
+            for (int rightItemIndex : itemCorrsAR.column(leftItemIndex).keySet()) {
+                temp[flag][0] = rightItemIndex;
+                temp[flag][1] = leftItemIndex;
+                temp[flag][2] = itemCorrsAR.get(rightItemIndex, leftItemIndex);
+                flag++;
+            }
+            if (size > neighborSize) {
+                // selection-sort only the k nearest neighbors to the front
+                for (int i = 0; i < neighborSize; i++) {
+                    for (int j = i + 1; j < size; j++) {
+                        if (temp[i][2] < temp[j][2]) {
+                            for (int k = 0; k < 3; k++) {
+                                float trans = temp[i][k];
+                                temp[i][k] = temp[j][k];
+                                temp[j][k] = trans;
+                            }
+                        }
+                    }
+                }
+                storeAssociationRule(neighborSize, temp);
+            } else {
+                storeAssociationRule(size, temp);
+            }
+        }
+    }
+
+    /**
+     * Store the association rules.
+     *
+     * @param size
+     * @param temp
+     */
+    private void storeAssociationRule(int size, float[][] temp) {
+        for (int i = 0; i < size; i++) {
+            int leftItemIndex = (int) (temp[i][0]);
+            int rightItemIndex = (int) (temp[i][1]);
+            itemCorrsAR_Sorted.put(leftItemIndex, rightItemIndex, temp[i][2]);
+        }
+    }
+
+    /**
+     * Find the itemsets which contain three items and store them in groupItemList.
+     */
+    private void computeAssociationRuleByGroup() {
+        for (int groupIndex : itemCorrsAR.columnKeySet()) {
+            Integer[] itemIndexes = itemCorrsAR_Sorted.column(groupIndex).keySet().toArray(new Integer[] {});
+            LinkedList<KeyValue<Integer, Integer>> groupItemList = new LinkedList<>();
+            for (int leftIndex = 0; leftIndex < itemIndexes.length - 1; leftIndex++) {
+                for (int rightIndex = leftIndex + 1; rightIndex < itemIndexes.length; rightIndex++) {
+                    if (itemCount.contains(itemIndexes[leftIndex], itemIndexes[rightIndex])) {
+                        groupItemList.add(new KeyValue<>(itemIndexes[leftIndex], itemIndexes[rightIndex]));
+                    }
+                }
+            }
+            computeAssociationRuleByGroup(groupIndex, groupItemList);
+        }
+    }
+
+    /**
+     * Compute the group-to-item AR and store them in the map itemCorrsGAR.
+     */
+    private void computeAssociationRuleByGroup(int groupIndex, LinkedList<KeyValue<Integer, Integer>> itemList) {
+        List<KeyValue<KeyValue<Integer, Integer>, Float>> coefficientList = new LinkedList<>();
+
+        for (KeyValue<Integer, Integer> keyValue : itemList) {
+            int leftIndex = keyValue.getKey();
+            int rightIndex = keyValue.getValue();
+            SparseVector groupVector = scoreMatrix.getColumnVector(groupIndex);
+            int count = 0;
+            for (VectorScalar term : groupVector) {
+                int userIndex = term.getIndex();
+                if (dataTable.contains(userIndex, leftIndex) && dataTable.contains(userIndex, rightIndex)) {
+                    count++;
+                }
+            }
+            if (count > 0) {
+                float shrink = count / (count + reliability);
+                int co_bc = itemCount.get(leftIndex, rightIndex);
+                float coefficient = shrink * (count + 0F) / co_bc;
+                coefficientList.add(new KeyValue<>(keyValue, coefficient));
+            }
+        }
+        itemCorrsGAR.put(groupIndex, new ArrayList<>(coefficientList));
+    }
+
+    /**
+     * Order the group-to-item AR and store them in the map itemCorrsGAR_Sorted.
+     */
+    private void sortAssociationRuleByGroup() {
+        for (int groupIndex : itemCorrsGAR.keySet()) {
+            List<KeyValue<KeyValue<Integer, Integer>, Float>> list = itemCorrsGAR.get(groupIndex);
+            if (list.size() > neighborSize) {
+                Collections.sort(list, (left, right) -> {
+                    return right.getValue().compareTo(left.getValue());
+                });
+                list = list.subList(0, neighborSize);
+            }
+
+            HashMatrix groupTable = new HashMatrix(true, itemSize, itemSize, new Long2FloatRBTreeMap());
+            for (KeyValue<KeyValue<Integer, Integer>, Float> keyValue : list) {
+                int leftItemIndex = keyValue.getKey().getKey();
+                int rightItemIndex = keyValue.getKey().getValue();
+                float correlation = keyValue.getValue();
+                groupTable.setValue(leftItemIndex, rightItemIndex, correlation);
+            }
+            itemCorrsGAR_Sorted.put(groupIndex, SparseMatrix.valueOf(itemSize, itemSize, groupTable));
+        }
+    }
+
+    /**
+     * Select item-to-item AR to complement the group-to-item AR.
+     */
+    private void complementAssociationRule() {
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            if (scoreMatrix.getColumnScope(itemIndex) == 0) {
+                continue;
+            }
+            SparseMatrix groupTable = itemCorrsGAR_Sorted.get(itemIndex);
+            if (groupTable != null) {
+                int groupSize = groupTable.getElementSize();
+                if (groupSize < neighborSize) {
+                    int complementSize = neighborSize - groupSize;
+                    int size = itemCorrsAR_Sorted.column(itemIndex).size();
+                    // TODO use KeyValue instead.
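+                    // When an item has fewer than neighborSize group rules, the gap is
+                    // filled with its strongest item-to-item rules: partially selection-sort
+                    // the candidates by weight and keep the top complementSize of them.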
+                    float[][] trans = new float[size][2];
+                    if (size > complementSize) {
+                        int count = 0;
+                        for (int id : itemCorrsAR_Sorted.column(itemIndex).keySet()) {
+                            float value = itemCorrsAR_Sorted.get(id, itemIndex);
+                            trans[count][0] = id;
+                            trans[count][1] = value;
+                            count++;
+                        }
+                        for (int x = 0; x < complementSize; x++) {
+                            for (int y = x + 1; y < trans.length; y++) {
+                                float x_value = trans[x][1];
+                                float y_value = trans[y][1];
+                                if (x_value < y_value) {
+                                    for (int z = 0; z < 2; z++) {
+                                        float tran = trans[x][z];
+                                        trans[x][z] = trans[y][z];
+                                        trans[y][z] = tran;
+                                    }
+                                }
+                            }
+                        }
+                        for (int x = 0; x < complementSize; x++) {
+                            int id = (int) (trans[x][0]);
+                            float value = trans[x][1];
+                            itemCorrsAR_added.put(id, itemIndex, value);
+                        }
+                    } else {
+                        storeCAR(itemIndex);
+                    }
+                }
+            } else {
+                storeCAR(itemIndex);
+            }
+        }
+    }
+
+    /**
+     * Store the complementary item-to-item AR in the table itemCorrsAR_added.
+     *
+     * @param leftItemIndex
+     */
+    private void storeCAR(int leftItemIndex) {
+        for (int rightItemIndex : itemCorrsAR_Sorted.column(leftItemIndex).keySet()) {
+            float value = itemCorrsAR_Sorted.get(rightItemIndex, leftItemIndex);
+            itemCorrsAR_added.put(rightItemIndex, leftItemIndex, value);
+        }
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/ItemKNNRatingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/ItemKNNRatingModel.java
new file mode 100644
index 0000000..3e71f9b
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/ItemKNNRatingModel.java
@@ -0,0 +1,79 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import java.util.Iterator;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.rns.model.collaborative.ItemKNNModel;
+
+/**
+ *
+ * Item KNN recommender
+ *
+ * <pre>
+ * Based on the LibRec team's implementation<br>
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class ItemKNNRatingModel extends ItemKNNModel {
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        SparseVector userVector = userVectors[userIndex];
+        MathVector neighbors = itemNeighbors[itemIndex];
+        if (userVector.getElementSize() == 0 || neighbors.getElementSize() == 0) {
+            instance.setQuantityMark(meanScore);
+            return;
+        }
+
+        float sum = 0F, absolute = 0F;
+        int count = 0;
+        int leftCursor = 0, rightCursor = 0, leftSize = userVector.getElementSize(), rightSize = neighbors.getElementSize();
+        Iterator<VectorScalar> leftIterator = userVector.iterator();
+        VectorScalar leftTerm = leftIterator.next();
+        Iterator<VectorScalar> rightIterator = neighbors.iterator();
+        VectorScalar rightTerm = rightIterator.next();
+        // merge walk over the two sorted index lists looking for shared indexes
+        while (leftCursor < leftSize && rightCursor < rightSize) {
+            if (leftTerm.getIndex() == rightTerm.getIndex()) {
+                count++;
+                float correlation = rightTerm.getValue();
+                float score = leftTerm.getValue();
+                sum += correlation * (score - itemMeans.getValue(rightTerm.getIndex()));
+                absolute += Math.abs(correlation);
+                if (leftIterator.hasNext()) {
+                    leftTerm = leftIterator.next();
+                }
+                if (rightIterator.hasNext()) {
+                    rightTerm = rightIterator.next();
+                }
+                leftCursor++;
+                rightCursor++;
+            } else if (leftTerm.getIndex() > rightTerm.getIndex()) {
+                if (rightIterator.hasNext()) {
+                    rightTerm = rightIterator.next();
+                }
+                rightCursor++;
+            } else if (leftTerm.getIndex() < rightTerm.getIndex()) {
+                if (leftIterator.hasNext()) {
+                    leftTerm = leftIterator.next();
+                }
+                leftCursor++;
+            }
+        }
+
+        if (count == 0) {
+            instance.setQuantityMark(meanScore);
+            return;
+        }
+
+        instance.setQuantityMark(absolute > 0 ? itemMeans.getValue(itemIndex) + sum / absolute : meanScore);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/KernelSmoother.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/KernelSmoother.java
new file mode 100644
index 0000000..f3081ef
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/KernelSmoother.java
@@ -0,0 +1,33 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+/**
+ * Kernel smoother
+ *
+ * <pre>
+ * {@link LLORMAModel}
+ * 
+ *
+ * @author Birdy
+ *
+ */
+enum KernelSmoother {
+
+    TRIANGULAR_KERNEL, UNIFORM_KERNEL, EPANECHNIKOV_KERNEL, GAUSSIAN_KERNEL;
+
+    public float kernelize(float similarity, float width) {
+        float distance = 1F - similarity;
+        switch (this) {
+        case TRIANGULAR_KERNEL:
+            return Math.max(1F - distance / width, 0F);
+        case UNIFORM_KERNEL:
+            return distance < width ? 1F : 0F;
+        case EPANECHNIKOV_KERNEL:
+            return (float) Math.max(3F / 4F * (1F - Math.pow(distance / width, 2F)), 0F);
+        case GAUSSIAN_KERNEL:
+            return (float) (1F / Math.sqrt(2F * Math.PI) * Math.exp(-0.5F * Math.pow(distance / width, 2F)));
+        default:
+            return Math.max(1F - distance / width, 0F);
+        }
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LDCCModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LDCCModel.java
new file mode 100644
index 0000000..0314b83
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LDCCModel.java
@@ -0,0 +1,291 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import java.util.Map.Entry;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.ProbabilisticGraphicalModel;
+import com.jstarcraft.rns.utility.SampleUtility;
+
+import it.unimi.dsi.fastutil.ints.Int2IntRBTreeMap;
+
+/**
+ *
+ * LDCC recommender
+ *
+ * <pre>
+ * Based on the LibRec team's implementation<br>
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class LDCCModel extends ProbabilisticGraphicalModel {
+
+    // TODO refactor as sparse matrices?
+    private Int2IntRBTreeMap userTopics, itemTopics; // Zu, Zv
+
+    private DenseMatrix userTopicTimes, itemTopicTimes; // Nui, Nvj
+    private DenseVector userScoreTimes, itemScoreTimes; // Nv
+
+    private DenseMatrix topicTimes;
+    private DenseMatrix topicProbabilities;
+
+    private DenseVector userProbabilities;
+    private DenseVector itemProbabilities;
+
+    private int[][][] rateTopicTimes;
+
+    private int numberOfUserTopics, numberOfItemTopics;
+
+    private float userAlpha, itemAlpha, ratingBeta;
+
+    private DenseMatrix userTopicProbabilities, itemTopicProbabilities;
+    private DenseMatrix userTopicSums, itemTopicSums;
+    private float[][][] rateTopicProbabilities, rateTopicSums;
+
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        numberOfStatistics = 0;
+
+        numberOfUserTopics = configuration.getInteger("recommender.pgm.number.users", 10);
+        numberOfItemTopics = configuration.getInteger("recommender.pgm.number.items", 10);
+
+        userAlpha = configuration.getFloat("recommender.pgm.user.alpha", 1F / numberOfUserTopics);
+        itemAlpha = configuration.getFloat("recommender.pgm.item.alpha", 1F / numberOfItemTopics);
+        ratingBeta = configuration.getFloat("recommender.pgm.rating.beta", 1F / actionSize);
+
+        userTopicTimes = DenseMatrix.valueOf(userSize, numberOfUserTopics);
+        itemTopicTimes = DenseMatrix.valueOf(itemSize, numberOfItemTopics);
+        userScoreTimes = DenseVector.valueOf(userSize);
+        itemScoreTimes = DenseVector.valueOf(itemSize);
+
+        rateTopicTimes = new int[numberOfUserTopics][numberOfItemTopics][actionSize];
+        topicTimes = DenseMatrix.valueOf(numberOfUserTopics, numberOfItemTopics);
+        topicProbabilities = DenseMatrix.valueOf(numberOfUserTopics, numberOfItemTopics);
+        userProbabilities = DenseVector.valueOf(numberOfUserTopics);
+        itemProbabilities = DenseVector.valueOf(numberOfItemTopics);
+
+        userTopics = new Int2IntRBTreeMap();
+        itemTopics = new Int2IntRBTreeMap();
+
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            int scoreIndex = scoreIndexes.get(score);
+
+            int userTopic = RandomUtility.randomInteger(numberOfUserTopics);
+            int itemTopic = RandomUtility.randomInteger(numberOfItemTopics);
+
+            userTopicTimes.shiftValue(userIndex, userTopic, 1);
+            userScoreTimes.shiftValue(userIndex, 1);
+
+            itemTopicTimes.shiftValue(itemIndex, itemTopic, 1);
+            itemScoreTimes.shiftValue(itemIndex, 1);
+
+            rateTopicTimes[userTopic][itemTopic][scoreIndex]++;
+            topicTimes.shiftValue(userTopic, itemTopic, 1);
+
+            userTopics.put(userIndex * itemSize + itemIndex, userTopic);
+            itemTopics.put(userIndex * itemSize + itemIndex, itemTopic);
+        }
+
+        // parameters
+        userTopicSums = DenseMatrix.valueOf(userSize, numberOfUserTopics);
+        itemTopicSums = DenseMatrix.valueOf(itemSize, numberOfItemTopics);
+        rateTopicProbabilities = new float[numberOfUserTopics][numberOfItemTopics][actionSize];
+        rateTopicSums = new float[numberOfUserTopics][numberOfItemTopics][actionSize];
+    }
+
+    @Override
+    protected void eStep() {
+        // cache the probabilities
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            // TODO this could be refactored
+            int scoreIndex = scoreIndexes.get(score);
+            // TODO this could be refactored
+            // user and item's factors
+            int userTopic = userTopics.get(userIndex * itemSize + itemIndex);
+            int itemTopic = itemTopics.get(userIndex * itemSize + itemIndex);
+
+            // remove this observation
+            userTopicTimes.shiftValue(userIndex, userTopic, -1);
+            userScoreTimes.shiftValue(userIndex, -1);
+
+            itemTopicTimes.shiftValue(itemIndex, itemTopic, -1);
+            itemScoreTimes.shiftValue(itemIndex, -1);
+
+            rateTopicTimes[userTopic][itemTopic][scoreIndex]--;
+            topicTimes.shiftValue(userTopic, itemTopic, -1);
+
+            // TODO topicProbabilities could probably be merged with userProbabilities and itemProbabilities.
+            // Compute P(i, j)
+            // normalization
+            topicProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                int row = scalar.getRow();
+                int column = scalar.getColumn();
+                // Compute Pmn
+                float v1 = (userTopicTimes.getValue(userIndex, row) + userAlpha) / (userScoreTimes.getValue(userIndex) + numberOfUserTopics * userAlpha);
+                // the original read userTopicTimes.getValue(topicIndex, column) here;
+                // the item-topic count is the intended term
+                float v2 = (itemTopicTimes.getValue(itemIndex, column) + itemAlpha) / (itemScoreTimes.getValue(itemIndex) + numberOfItemTopics * itemAlpha);
+                float v3 = (rateTopicTimes[row][column][scoreIndex] + ratingBeta) / (topicTimes.getValue(row, column) + actionSize * ratingBeta);
+                float value = v1 * v2 * v3;
+                scalar.setValue(value);
+            });
+            // Re-sample user factor
+            // compute the cumulative probabilities
+            DefaultScalar sum = DefaultScalar.getInstance();
+            sum.setValue(0F);
+            userProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                int index = scalar.getIndex();
+                float value = topicProbabilities.getRowVector(index).getSum(false);
+                sum.shiftValue(value);
+                scalar.setValue(sum.getValue());
+            });
+            userTopic = SampleUtility.binarySearch(userProbabilities, 0, userProbabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue()));
+            sum.setValue(0F);
+            itemProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> {
+                int index = scalar.getIndex();
+                float value = topicProbabilities.getColumnVector(index).getSum(false);
+                sum.shiftValue(value);
+                scalar.setValue(sum.getValue());
+            });
+            itemTopic = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue()));
+
+            // Add statistics
+            userTopicTimes.shiftValue(userIndex, userTopic, 1);
+            userScoreTimes.shiftValue(userIndex, 1);
+
+            itemTopicTimes.shiftValue(itemIndex, itemTopic, 1);
+            itemScoreTimes.shiftValue(itemIndex, 1);
+
+            rateTopicTimes[userTopic][itemTopic][scoreIndex]++;
+            topicTimes.shiftValue(userTopic, itemTopic, 1);
+
+            userTopics.put(userIndex * itemSize + itemIndex, userTopic);
+            itemTopics.put(userIndex * itemSize + itemIndex, itemTopic);
+        }
+    }
+
+    @Override
+    protected void mStep() {
+        // TODO Auto-generated method stub
+
+    }
+
+    @Override
+    protected void readoutParameters() {
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            for (int topicIndex = 0; topicIndex < numberOfUserTopics; topicIndex++) {
+                userTopicSums.shiftValue(userIndex, topicIndex, (userTopicTimes.getValue(userIndex, topicIndex) + userAlpha) / (userScoreTimes.getValue(userIndex) + numberOfUserTopics * userAlpha));
+            }
+        }
+
+        for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
+            for (int topicIndex = 0; topicIndex < numberOfItemTopics; topicIndex++) {
+                itemTopicSums.shiftValue(itemIndex, topicIndex, (itemTopicTimes.getValue(itemIndex, topicIndex) + itemAlpha) / (itemScoreTimes.getValue(itemIndex) + numberOfItemTopics * itemAlpha));
+            }
+        }
+
+        for (int userTopic = 0; userTopic < numberOfUserTopics; userTopic++) {
+            for (int itemTopic = 0; itemTopic < numberOfItemTopics; itemTopic++) {
+                for (int scoreIndex = 0; scoreIndex < actionSize; scoreIndex++) {
+                    rateTopicSums[userTopic][itemTopic][scoreIndex] += (rateTopicTimes[userTopic][itemTopic][scoreIndex] + ratingBeta) / (topicTimes.getValue(userTopic, itemTopic) + actionSize * ratingBeta);
+                }
+            }
+        }
+        numberOfStatistics++;
+    }
+
+    /**
+     * estimate the model parameters
+     */
+    @Override
+    protected void estimateParameters() {
+        float scale = 1F / numberOfStatistics;
+        // TODO this could be refactored (merge userTopicProbabilities/userTopicSums with itemTopicProbabilities/itemTopicSums)
+        userTopicProbabilities = DenseMatrix.copyOf(userTopicSums);
+        userTopicProbabilities.scaleValues(scale);
+        itemTopicProbabilities = DenseMatrix.copyOf(itemTopicSums);
+        itemTopicProbabilities.scaleValues(scale);
+
+        // TODO this could be refactored (merge rateTopicProbabilities/rateTopicSums)
+        for (int userTopic = 0; userTopic < numberOfUserTopics; userTopic++) {
+            for (int itemTopic = 0; itemTopic < numberOfItemTopics; itemTopic++) {
+                for (int scoreIndex = 0; scoreIndex < actionSize; scoreIndex++) {
+                    rateTopicProbabilities[userTopic][itemTopic][scoreIndex] = rateTopicSums[userTopic][itemTopic][scoreIndex] / numberOfStatistics;
+                }
+            }
+        }
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        float value = 0F;
+        for (Entry<Float, Integer> term : scoreIndexes.entrySet()) {
+            float score = term.getKey();
+            int scoreIndex = term.getValue();
+            float probability = 0F; // P(r|u,v)=\sum_{i,j} P(r|i,j)P(i|u)P(j|v)
+            for (int userTopic = 0; userTopic < numberOfUserTopics; userTopic++) {
+                for (int itemTopic = 0; itemTopic < numberOfItemTopics; itemTopic++) {
+                    probability += rateTopicProbabilities[userTopic][itemTopic][scoreIndex] * userTopicProbabilities.getValue(userIndex, userTopic) * itemTopicProbabilities.getValue(itemIndex, itemTopic);
+                }
+            }
+            value += score * probability;
+        }
+        instance.setQuantityMark(value);
+    }
+
+    @Override
+    protected boolean isConverged(int iter) {
+        // Get the parameters
+        estimateParameters();
+        // Compute the perplexity
+        float sum = 0F;
+        for (MatrixScalar term : scoreMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            sum += perplexity(userIndex, itemIndex, score);
+        }
+        float perplexity = (float) Math.exp(sum / actionSize);
+        float delta = perplexity - currentError;
+        if (numberOfStatistics > 1 && delta > 0) {
+            return true;
+        }
+        currentError = perplexity;
+        return false;
+    }
+
+    private double perplexity(int user, int item, double score) {
+        int scoreIndex = (int) (score / minimumScore - 1);
+        // Compute P(r | u, v)
+        double probability = 0;
+        for (int userTopic = 0; userTopic < numberOfUserTopics; userTopic++) {
+            for (int itemTopic = 0; itemTopic < numberOfItemTopics; itemTopic++) {
+                probability += rateTopicProbabilities[userTopic][itemTopic][scoreIndex] * userTopicProbabilities.getValue(user, userTopic) * itemTopicProbabilities.getValue(item, itemTopic);
+            }
+        }
+        return -Math.log(probability);
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LLORMALearner.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LLORMALearner.java
new file mode 100644
index 0000000..0b46c9e
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LLORMALearner.java
@@ -0,0 +1,161 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.matrix.SparseMatrix;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+
+/**
+ *
+ * LLORMA learner
+ *
+ * <pre>
+ * Local Low-Rank Matrix Approximation
+ * Based on the LibRec team's implementation<br>
+ * 
+ *
+ * @author Birdy
+ *
+ */
+public class LLORMALearner extends Thread {
+    /**
+     * The unique identifier of the thread.
+     */
+    private int threadId;
+
+    /**
+     * The number of features.
+     */
+    private int numberOfFactors;
+
+    /**
+     * Learning rate parameter.
+     */
+    private float learnRatio;
+
+    /**
+     * The maximum number of iterations.
+     */
+    private int localIteration;
+
+    /**
+     * Regularization factor parameters.
+     */
+    private float userRegularization, itemRegularization;
+
+    /**
+     * User profile in low-rank matrix form.
+     */
+    private DenseMatrix userFactors;
+
+    /**
+     * Item profile in low-rank matrix form.
+     */
+    private DenseMatrix itemFactors;
+
+    /**
+     * The vector containing each user's weight.
+     */
+    private DenseVector userWeights;
+
+    /**
+     * The vector containing each item's weight.
+     */
+    private DenseVector itemWeights;
+
+    /**
+     * The rating matrix used for learning.
+     */
+    private SparseMatrix trainMatrix;
+
+    /**
+     * Construct a local model for singleton LLORMA.
+     *
+     * @param threadId           A unique thread ID.
+     * @param numberOfFactors    The rank which will be used in this local model.
+     * @param learnRatio         Learning rate parameter.
+     * @param userRegularization Regularization parameter for users.
+     * @param itemRegularization Regularization parameter for items.
+     * @param localIteration     The maximum number of local iterations.
+     * @param userFactors        Initial user profile in low-rank matrix form.
+     * @param itemFactors        Initial item profile in low-rank matrix form.
+     * @param userWeights        Initial vector containing each user's weight.
+     * @param itemWeights        Initial vector containing each item's weight.
+     * @param trainMatrix        The rating matrix used for learning.
+     */
+    public LLORMALearner(int threadId, int numberOfFactors, float learnRatio, float userRegularization, float itemRegularization, int localIteration, DenseMatrix userFactors, DenseMatrix itemFactors, DenseVector userWeights, DenseVector itemWeights, SparseMatrix trainMatrix) {
+        this.threadId = threadId;
+        this.numberOfFactors = numberOfFactors;
+        this.learnRatio = learnRatio;
+        this.userRegularization = userRegularization;
+        this.itemRegularization = itemRegularization;
+        this.localIteration = localIteration;
+        this.userWeights = userWeights;
+        this.itemWeights = itemWeights;
+        this.userFactors = userFactors;
+        this.itemFactors = itemFactors;
+        this.trainMatrix = trainMatrix;
+    }
+
+    public int getIndex() {
+        return threadId;
+    }
+
+    /**
+     * Getter method for the user profile of this local model.
+     *
+     * @return The user profile of this local model.
+     */
+    public DenseMatrix getUserFactors() {
+        return userFactors;
+    }
+
+    /**
+     * Getter method for the item profile of this local model.
+     *
+     * @return The item profile of this local model.
+     */
+    public DenseMatrix getItemFactors() {
+        return itemFactors;
+    }
+
+    /**
+     * Learn this local model based on users similar to the anchor user and items
+     * similar to the anchor item. Implemented with gradient descent.
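+     *
+     * Each observed rating contributes a locally weighted squared error; per factor f,
+     * the update in run() is p(u,f) += learnRatio * (error * q(i,f) * weight - userRegularization * p(u,f)),
+     * and symmetrically for q(i,f), with weight = userWeights(u) * itemWeights(i).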
+     */
+    @Override
+    public void run() {
+        // Learn by weighted RegSVD
+        for (int iterationStep = 0; iterationStep < localIteration; iterationStep++) {
+            for (MatrixScalar term : trainMatrix) {
+                int userIndex = term.getRow(); // user
+                int itemIndex = term.getColumn(); // item
+                float score = term.getValue();
+
+                float predict = predict(userIndex, itemIndex);
+                float error = score - predict;
+                float weight = userWeights.getValue(userIndex) * itemWeights.getValue(itemIndex);
+
+                // update factors
+                for (int factorIndex = 0; factorIndex < numberOfFactors; factorIndex++) {
+                    float userFactorValue = userFactors.getValue(userIndex, factorIndex);
+                    float itemFactorValue = itemFactors.getValue(itemIndex, factorIndex);
+                    userFactors.shiftValue(userIndex, factorIndex, learnRatio * (error * itemFactorValue * weight - userRegularization * userFactorValue));
+                    itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * (error * userFactorValue * weight - itemRegularization * itemFactorValue));
+                }
+            }
+        }
+    }
+
+    private float predict(int userIndex, int itemIndex) {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        DenseVector userVector = userFactors.getRowVector(userIndex);
+        DenseVector itemVector = itemFactors.getRowVector(itemIndex);
+        float value = scalar.dotProduct(userVector, itemVector).getValue();
+        return value;
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LLORMAModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LLORMAModel.java
new file mode 100644
index 0000000..2c70687
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/LLORMAModel.java
@@ -0,0 +1,270 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.math.structure.DefaultScalar;
+import com.jstarcraft.ai.math.structure.MathCalculator;
+import com.jstarcraft.ai.math.structure.matrix.DenseMatrix;
+import com.jstarcraft.ai.math.structure.matrix.MatrixScalar;
+import com.jstarcraft.ai.math.structure.vector.DenseVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.model.MatrixFactorizationModel;
+
+/**
+ *
+ * LLORMA recommender
+ *
+ * <pre>
+ * Local Low-Rank Matrix Approximation
+ * Based on the LibRec team's implementation<br>
+ * 
+ * + * @author Birdy + * + */ +public class LLORMAModel extends MatrixFactorizationModel { + private int numberOfGlobalFactors, numberOfLocalFactors; + private int globalEpocheSize, localEpocheSize; + private int numberOfThreads; + private float globalUserRegularization, globalItemRegularization, localUserRegularization, localItemRegularization; + private float globalLearnRatio, localLearnRatio; + + private int numberOfModels; + private DenseMatrix globalUserFactors, globalItemFactors; + + private DenseMatrix[] userMatrixes; + + private DenseMatrix[] itemMatrixes; + + private int[] anchorUsers; + private int[] anchorItems; + + /* + * (non-Javadoc) + * + * @see net.librecommender.recommender.AbstractRecommender#setup() + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + numberOfGlobalFactors = configuration.getInteger("recommender.global.factors.num", 20); + numberOfLocalFactors = factorSize; + + globalEpocheSize = configuration.getInteger("recommender.global.iteration.maximum", 100); + localEpocheSize = epocheSize; + + globalUserRegularization = configuration.getFloat("recommender.global.user.regularization", 0.01F); + globalItemRegularization = configuration.getFloat("recommender.global.item.regularization", 0.01F); + localUserRegularization = userRegularization; + localItemRegularization = itemRegularization; + + globalLearnRatio = configuration.getFloat("recommender.global.iteration.learnrate", 0.01F); + localLearnRatio = configuration.getFloat("recommender.iteration.learnrate", 0.01F); + + numberOfThreads = configuration.getInteger("recommender.thread.count", 4); + numberOfModels = configuration.getInteger("recommender.model.num", 50); + + numberOfThreads = numberOfThreads > numberOfModels ? 
numberOfModels : numberOfThreads; + + // global svd P Q to calculate the kernel value between users (or items) + globalUserFactors = DenseMatrix.valueOf(userSize, numberOfGlobalFactors); + globalUserFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + globalItemFactors = DenseMatrix.valueOf(itemSize, numberOfGlobalFactors); + globalItemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + } + + // global svd P Q + private void practiceGlobalModel(DefaultScalar scalar) { + for (int epocheIndex = 0; epocheIndex < globalEpocheSize; epocheIndex++) { + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // user + int itemIndex = term.getColumn(); // item + float score = term.getValue(); + + // TODO 考虑重构,减少userVector与itemVector的重复构建 + DenseVector userVector = globalUserFactors.getRowVector(userIndex); + DenseVector itemVector = globalItemFactors.getRowVector(itemIndex); + float predict = scalar.dotProduct(userVector, itemVector).getValue(); + float error = score - predict; + + // update factors + for (int factorIndex = 0; factorIndex < numberOfGlobalFactors; factorIndex++) { + float userFactor = globalUserFactors.getValue(userIndex, factorIndex); + float itemFactor = globalItemFactors.getValue(itemIndex, factorIndex); + globalUserFactors.shiftValue(userIndex, factorIndex, globalLearnRatio * (error * itemFactor - globalUserRegularization * userFactor)); + globalItemFactors.shiftValue(itemIndex, factorIndex, globalLearnRatio * (error * userFactor - globalItemRegularization * itemFactor)); + } + } + } + + userMatrixes = new DenseMatrix[numberOfModels]; + itemMatrixes = new DenseMatrix[numberOfModels]; + anchorUsers = new int[numberOfModels]; + anchorItems = new int[numberOfModels]; + // end of training + } + + /** + * Calculate similarity between two users, based on the global base SVD. + * + * @param leftUserIndex The first user's ID. + * @param rightUserIndex The second user's ID. + * @return The similarity value between two users idx1 and idx2. + */ + private float getUserSimilarity(DefaultScalar scalar, int leftUserIndex, int rightUserIndex) { + float similarity; + // TODO 减少向量的重复构建 + DenseVector leftUserVector = globalUserFactors.getRowVector(leftUserIndex); + DenseVector rightUserVector = globalUserFactors.getRowVector(rightUserIndex); + similarity = (float) (1 - 2F / Math.PI * Math.acos(scalar.dotProduct(leftUserVector, rightUserVector).getValue() / (Math.sqrt(scalar.dotProduct(leftUserVector, leftUserVector).getValue()) * Math.sqrt(scalar.dotProduct(rightUserVector, rightUserVector).getValue())))); + if (Float.isNaN(similarity)) { + similarity = 0F; + } + return similarity; + } + + /** + * Calculate similarity between two items, based on the global base SVD. + * + * @param leftItemIndex The first item's ID. + * @param rightItemIndex The second item's ID. + * @return The similarity value between two items idx1 and idx2. 
+ */ + private float getItemSimilarity(DefaultScalar scalar, int leftItemIndex, int rightItemIndex) { + float similarity; + // TODO 减少向量的重复构建 + DenseVector leftItemVector = globalItemFactors.getRowVector(leftItemIndex); + DenseVector rightItemVector = globalItemFactors.getRowVector(rightItemIndex); + similarity = (float) (1 - 2D / Math.PI * Math.acos(scalar.dotProduct(leftItemVector, rightItemVector).getValue() / (Math.sqrt(scalar.dotProduct(leftItemVector, leftItemVector).getValue()) * Math.sqrt(scalar.dotProduct(rightItemVector, rightItemVector).getValue())))); + if (Float.isNaN(similarity)) { + similarity = 0F; + } + return similarity; + } + + /** + * Given the similarity, it applies the given kernel. This is done either for + * all users or for all items. + * + * @param size The length of user or item vector. + * @param anchorIdx The identifier of anchor point. + * @param type The type of kernel. + * @param width Kernel width. + * @param isItemFeature return item kernel if yes, return user kernel otherwise. + * @return The kernel-smoothed values for all users or all items. + */ + private DenseVector kernelSmoothing(DefaultScalar scalar, int size, int anchorIdx, KernelSmoother type, float width, boolean isItemFeature) { + DenseVector featureVector = DenseVector.valueOf(size); + // TODO 此处似乎有Bug? + featureVector.setValue(anchorIdx, 1F); + for (int index = 0; index < size; index++) { + float similarity; + if (isItemFeature) { + similarity = getItemSimilarity(scalar, index, anchorIdx); + } else { // userFeature + similarity = getUserSimilarity(scalar, index, anchorIdx); + } + featureVector.setValue(index, type.kernelize(similarity, width)); + } + return featureVector; + } + + private void practiceLocalModels(DefaultScalar scalar) { + // Pre-calculating similarity: + int completeModelCount = 0; + + // TODO 此处的变量与矩阵可以整合到LLORMALearner,LLORMALearner变成任务. + LLORMALearner[] learners = new LLORMALearner[numberOfThreads]; + + int modelCount = 0; + int[] runningThreadList = new int[numberOfThreads]; + int runningThreadCount = 0; + int waitingThreadPointer = 0; + int nextRunningSlot = 0; + + // Parallel training: + while (completeModelCount < numberOfModels) { + int randomUserIndex = RandomUtility.randomInteger(userSize); + // TODO 考虑重构 + SparseVector userVector = scoreMatrix.getRowVector(randomUserIndex); + if (userVector.getElementSize() == 0) { + continue; + } + // TODO 此处的并发模型有问题,需要重构.否则当第一次runningThreadCount >= + // numThreads之后,都是单线程执行. 
+ if (runningThreadCount < numberOfThreads && modelCount < numberOfModels) { + // Selecting a new anchor point: + int randomItemIndex = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize())); + anchorUsers[modelCount] = randomUserIndex; + anchorItems[modelCount] = randomItemIndex; + // Preparing weight vectors: + DenseVector userWeights = kernelSmoothing(scalar, userSize, randomUserIndex, KernelSmoother.EPANECHNIKOV_KERNEL, 0.8F, false); + DenseVector itemWeights = kernelSmoothing(scalar, itemSize, randomItemIndex, KernelSmoother.EPANECHNIKOV_KERNEL, 0.8F, true); + DenseMatrix localUserFactors = DenseMatrix.valueOf(userSize, numberOfLocalFactors); + localUserFactors.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(distribution.sample().floatValue()); + }); + DenseMatrix localItemFactors = DenseMatrix.valueOf(itemSize, numberOfLocalFactors); + localItemFactors.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(distribution.sample().floatValue()); + }); + // Starting a new local model learning: + learners[nextRunningSlot] = new LLORMALearner(modelCount, numberOfLocalFactors, localLearnRatio, localUserRegularization, localItemRegularization, localEpocheSize, localUserFactors, localItemFactors, userWeights, itemWeights, scoreMatrix); + learners[nextRunningSlot].start(); + runningThreadList[runningThreadCount] = modelCount; + runningThreadCount++; + modelCount++; + nextRunningSlot++; + } else if (runningThreadCount > 0) { + // Joining a local model which was done with learning: + try { + learners[waitingThreadPointer].join(); + } catch (InterruptedException ie) { + logger.error("Join failed: " + ie); + } + LLORMALearner learner = learners[waitingThreadPointer]; + userMatrixes[learner.getIndex()] = learner.getUserFactors(); + itemMatrixes[learner.getIndex()] = learner.getItemFactors(); + nextRunningSlot = waitingThreadPointer; + waitingThreadPointer = (waitingThreadPointer + 1) % numberOfThreads; + runningThreadCount--; + completeModelCount++; + } + } + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + practiceGlobalModel(scalar); + practiceLocalModels(scalar); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + float weightSum = 0F; + float valueSum = 0F; + for (int iterationStep = 0; iterationStep < numberOfModels; iterationStep++) { + float weight = KernelSmoother.EPANECHNIKOV_KERNEL.kernelize(getUserSimilarity(scalar, anchorUsers[iterationStep], userIndex), 0.8F) * KernelSmoother.EPANECHNIKOV_KERNEL.kernelize(getItemSimilarity(scalar, anchorItems[iterationStep], itemIndex), 0.8F); + float value = (scalar.dotProduct(userMatrixes[iterationStep].getRowVector(userIndex), itemMatrixes[iterationStep].getRowVector(itemIndex)).getValue()) * weight; + weightSum += weight; + valueSum += value; + } + float score = valueSum / weightSum; + if (Float.isNaN(score) || score == 0F) { + score = meanScore; + } + instance.setQuantityMark(score); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/MFALSModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/MFALSModel.java new file mode 100644 index 0000000..79cec1b --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/MFALSModel.java @@ -0,0 +1,103 @@ +package 
com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.math.MatrixUtility; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * MF ALS推荐器 + * + *
+ * Large-Scale Parallel Collaborative Filtering for the Netflix Prize
+ * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/
+ * Reference: LibRec team
+ * 
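+ * 
+ * Each epoch alternates two ridge-regression solves (see doPractice); informally:
+ * P[u] = inverse(Mu' * Mu + userRegularization * |Iu| * I) * Mu' * ru
+ * Q[i] = inverse(Ui' * Ui + itemRegularization * |Ui| * I) * Ui' * ri
+ * where Mu stacks the factors of the items rated by user u and ru holds those ratings.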
+ * + * @author Birdy + * + */ +public class MFALSModel extends MatrixFactorizationModel { + + @Override + protected void doPractice() { + DenseVector scoreVector = DenseVector.valueOf(factorSize); + DenseMatrix inverseMatrix = DenseMatrix.valueOf(factorSize, factorSize); + DenseMatrix transposeMatrix = DenseMatrix.valueOf(factorSize, factorSize); + DenseMatrix copyMatrix = DenseMatrix.valueOf(factorSize, factorSize); + // TODO 可以考虑只获取有评分的用户? + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + // fix item matrix M, solve user matrix U + for (int userIndex = 0; userIndex < userSize; userIndex++) { + // number of items rated by user userIdx + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + int size = userVector.getElementSize(); + if (size == 0) { + continue; + } + // TODO 此处应该避免valueOf + DenseMatrix rateMatrix = DenseMatrix.valueOf(size, factorSize); + DenseVector rateVector = DenseVector.valueOf(size); + int index = 0; + for (VectorScalar term : userVector) { + // step 1: + int itemIndex = term.getIndex(); + rateMatrix.getRowVector(index).copyVector(itemFactors.getRowVector(itemIndex)); + + // step 2: + // ratings of this userIdx + rateVector.setValue(index++, term.getValue()); + } + + // step 3: the updated user matrix wrt user j + DenseMatrix matrix = transposeMatrix; + matrix.dotProduct(rateMatrix, true, rateMatrix, false, MathCalculator.SERIAL); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + matrix.shiftValue(factorIndex, factorIndex, userRegularization * size); + } + scoreVector.dotProduct(rateMatrix, true, rateVector, MathCalculator.SERIAL); + userFactors.getRowVector(userIndex).dotProduct(MatrixUtility.inverse(matrix, copyMatrix, inverseMatrix), false, scoreVector, MathCalculator.SERIAL); + } + + // TODO 可以考虑只获取有评分的条目? 
+ // fix user matrix U, solve item matrix M + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + // latent factor of users that have rated item itemIdx + // number of users rate item j + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + int size = itemVector.getElementSize(); + if (size == 0) { + continue; + } + + // TODO 此处应该避免valueOf + DenseMatrix rateMatrix = DenseMatrix.valueOf(size, factorSize); + DenseVector rateVector = DenseVector.valueOf(size); + int index = 0; + for (VectorScalar term : itemVector) { + // step 1: + int userIndex = term.getIndex(); + rateMatrix.getRowVector(index).copyVector(userFactors.getRowVector(userIndex)); + + // step 2: + // ratings of this item + rateVector.setValue(index++, term.getValue()); + } + + // step 3: the updated item matrix wrt item j + DenseMatrix matrix = transposeMatrix; + matrix.dotProduct(rateMatrix, true, rateMatrix, false, MathCalculator.SERIAL); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + matrix.shiftValue(factorIndex, factorIndex, itemRegularization * size); + } + scoreVector.dotProduct(rateMatrix, true, rateVector, MathCalculator.SERIAL); + itemFactors.getRowVector(itemIndex).dotProduct(MatrixUtility.inverse(matrix, copyMatrix, inverseMatrix), false, scoreVector, MathCalculator.SERIAL); + } + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/NMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/NMFModel.java new file mode 100644 index 0000000..b572d7f --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/NMFModel.java @@ -0,0 +1,105 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.MathUtility; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.ArrayVector; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * NMF推荐器 + * + *
+ * Algorithms for Non-negative Matrix Factorization
+ * Reference: LibRec team
+ * 
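+ * 
+ * Lee-Seung multiplicative updates, as implemented in doPractice; informally:
+ * P[u][k] = P[u][k] * (sum_i Q[i][k] * r(u, i)) / (sum_i Q[i][k] * predict(u, i) + EPSILON)
+ * and symmetrically for Q, which keeps both factor matrices non-negative.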
+ * + * @author Birdy + * + */ +public class NMFModel extends MatrixFactorizationModel { + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(0.01F)); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(0.01F)); + }); + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + for (int epocheIndex = 0; epocheIndex < epocheSize; ++epocheIndex) { + // update userFactors by fixing itemFactors + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0) { + continue; + } + int user = userIndex; + ArrayVector predictVector = new ArrayVector(userVector); + predictVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predict(user, element.getIndex())); + }); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + DenseVector factorVector = itemFactors.getColumnVector(factorIndex); + float score = scalar.dotProduct(factorVector, userVector).getValue(); + float predict = scalar.dotProduct(factorVector, predictVector).getValue() + MathUtility.EPSILON; + userFactors.setValue(userIndex, factorIndex, userFactors.getValue(userIndex, factorIndex) * (score / predict)); + } + } + + // update itemFactors by fixing userFactors + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + if (itemVector.getElementSize() == 0) { + continue; + } + int item = itemIndex; + ArrayVector predictVector = new ArrayVector(itemVector); + predictVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predict(element.getIndex(), item)); + }); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + DenseVector factorVector = userFactors.getColumnVector(factorIndex); + float score = scalar.dotProduct(factorVector, itemVector).getValue(); + float predict = scalar.dotProduct(factorVector, predictVector).getValue() + MathUtility.EPSILON; + itemFactors.setValue(itemIndex, factorIndex, itemFactors.getValue(itemIndex, factorIndex) * (score / predict)); + } + } + + // compute errors + totalError = 0F; + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + if (score > 0) { + float error = predict(userIndex, itemIndex) - score; + totalError += error * error; + } + } + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/PMFModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/PMFModel.java new file mode 100644 index 0000000..325f957 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/PMFModel.java @@ -0,0 +1,50 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * PMF推荐器 + * + *
+ * PMF: Probabilistic Matrix Factorization
+ * Reference: LibRec team
+ * 
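+ * 
+ * Maximum a posteriori estimation by SGD on the regularized squared error; per sample:
+ * P[u][k] = P[u][k] + learnRatio * (error * Q[i][k] - userRegularization * P[u][k])
+ * Q[i][k] = Q[i][k] + learnRatio * (error * P[u][k] - itemRegularization * Q[i][k])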
+ * + * @author Birdy + * + */ +public class PMFModel extends MatrixFactorizationModel { + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // user + int itemIndex = term.getColumn(); // item + float score = term.getValue(); + float predict = predict(userIndex, itemIndex); + float error = score - predict; + totalError += error * error; + + // update factors + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex), itemFactor = itemFactors.getValue(itemIndex, factorIndex); + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (error * itemFactor - userRegularization * userFactor)); + itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * (error * userFactor - itemRegularization * itemFactor)); + totalError += userRegularization * userFactor * userFactor + itemRegularization * itemFactor * itemFactor; + } + } + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/RBMModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/RBMModel.java new file mode 100644 index 0000000..572e5b5 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/RBMModel.java @@ -0,0 +1,413 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Map.Entry; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.random.JDKRandomGenerator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.algorithm.probability.QuantityProbability; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.ProbabilisticGraphicalModel; + +/** + * + * RBM推荐器 + * + *
+ * Restricted Boltzmann Machines for Collaborative Filtering
+ * Reference: LibRec team
+ * 
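+ * 
+ * Visible units are a softmax over the score levels of each rated item and hidden
+ * units are binary; training is contrastive divergence with recommender.tstep Gibbs
+ * steps. Informally, for hidden unit j with one-hot visible states v:
+ * p(h[j] = 1 | v) = sigmoid(sum over items i and levels k of v[i][k] * W[i][k][j] + hiddenBias[j])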
+ * + * @author Birdy + * + */ +public class RBMModel extends ProbabilisticGraphicalModel { + + private int steps; + private float epsilonWeight; + private float epsilonExplicitBias; + private float epsilonImplicitBias; + private float momentum; + private float lamtaWeight; + private float lamtaBias; + + private float[][][] weightSums; + private float[][][] weightProbabilities; + private float[][] explicitBiasSums; + private float[][] explicitBiasProbabilities; + private float[] implicitBiasSums; + private float[] implicitBiasProbabilities; + + private float[][][] positiveWeights; + private float[][][] negativeWeights; + + private float[][] positiveExplicitActs; + private float[][] negativeExplicitActs; + + private float[] positiveImplicitActs; + private float[] negativeImplicitActs; + + private int[] itemCount; + private PredictionType predictionType; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // TODO 此处可以重构 + epocheSize = configuration.getInteger("recommender.iterator.maximum", 10); + sampleSize = configuration.getInteger("recommender.sample.mumber", 100); + scoreSize = scoreIndexes.size() + 1; + factorSize = configuration.getInteger("recommender.factor.number", 500); + + epsilonWeight = configuration.getFloat("recommender.epsilonw", 0.001F); + epsilonExplicitBias = configuration.getFloat("recommender.epsilonvb", 0.001F); + epsilonImplicitBias = configuration.getFloat("recommender.epsilonhb", 0.001F); + steps = configuration.getInteger("recommender.tstep", 1); + momentum = configuration.getFloat("recommender.momentum", 0F); + lamtaWeight = configuration.getFloat("recommender.lamtaw", 0.001F); + lamtaBias = configuration.getFloat("recommender.lamtab", 0F); + predictionType = PredictionType.valueOf(configuration.getString("recommender.predictiontype", "mean").toUpperCase()); + + weightProbabilities = new float[itemSize][scoreSize][factorSize]; + explicitBiasProbabilities = new float[itemSize][scoreSize]; + implicitBiasProbabilities = new float[factorSize]; + + weightSums = new float[itemSize][scoreSize][factorSize]; + implicitBiasSums = new float[factorSize]; + explicitBiasSums = new float[itemSize][scoreSize]; + + positiveWeights = new float[itemSize][scoreSize][factorSize]; + negativeWeights = new float[itemSize][scoreSize][factorSize]; + + positiveImplicitActs = new float[factorSize]; + negativeImplicitActs = new float[factorSize]; + + positiveExplicitActs = new float[itemSize][scoreSize]; + negativeExplicitActs = new float[itemSize][scoreSize]; + + itemCount = new int[itemSize]; + + // TODO 此处需要重构 + int[][] itemScoreCount = new int[itemSize][scoreSize]; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0) { + continue; + } + for (VectorScalar term : userVector) { + int scoreIndex = scoreIndexes.get(term.getValue()); + itemScoreCount[term.getIndex()][scoreIndex]++; + } + } + QuantityProbability distribution = new QuantityProbability(JDKRandomGenerator.class, 0, NormalDistribution.class, 0D, 0.01D); + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + weightProbabilities[itemIndex][scoreIndex][factorIndex] = distribution.sample().floatValue(); + } + } + } + + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + double 
totalScore = 0D; + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + totalScore += itemScoreCount[itemIndex][scoreIndex]; + } + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + if (totalScore == 0D) { + explicitBiasProbabilities[itemIndex][scoreIndex] = RandomUtility.randomFloat(0.001F); + } else { + explicitBiasProbabilities[itemIndex][scoreIndex] = (float) Math.log(itemScoreCount[itemIndex][scoreIndex] / totalScore); + // visbiases[i][k] = Math.log(((moviecount[i][k]) + 1) / + // (trainMatrix.columnSize(i)+ softmax)); + } + } + } + } + + @Override + protected void doPractice() { + Collection currentImplicitStates; + Collection positiveImplicitStates = new ArrayList<>(factorSize); + Collection negativeImplicitStates = new ArrayList<>(factorSize); + DenseVector negativeExplicitProbabilities = DenseVector.valueOf(scoreSize); + int[] negativeExplicitScores = new int[itemSize]; + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + reset(); + // 随机遍历顺序 + Integer[] userIndexes = new Integer[userSize]; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + userIndexes[userIndex] = userIndex; + } + RandomUtility.shuffle(userIndexes); + for (int userIndex : userIndexes) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0) { + continue; + } + DenseVector factorSum = DenseVector.valueOf(factorSize); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + int scoreIndex = scoreIndexes.get(term.getValue()); + itemCount[itemIndex]++; + positiveExplicitActs[itemIndex][scoreIndex] += 1F; + factorSum.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value + weightProbabilities[itemIndex][scoreIndex][index]); + }); + } + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float probability = (float) (1F / (1F + Math.exp(-factorSum.getValue(factorIndex) - implicitBiasProbabilities[factorIndex]))); + if (probability > RandomUtility.randomFloat(1F)) { + positiveImplicitStates.add(factorIndex); + positiveImplicitActs[factorIndex] += 1F; + } + } + + currentImplicitStates = positiveImplicitStates; + + int step = 0; + + do { + boolean isLast = (step + 1 >= steps); + for (VectorScalar term : userVector) { + negativeExplicitProbabilities.setValues(0F); + int itemIndex = term.getIndex(); + for (int factorIndex : currentImplicitStates) { + negativeExplicitProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value + weightProbabilities[itemIndex][index][factorIndex]); + }); + } + + // 归一化 + negativeExplicitProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + value = (float) (1F / (1F + Math.exp(-value - explicitBiasProbabilities[itemIndex][index]))); + scalar.setValue(value); + }); + negativeExplicitProbabilities.scaleValues(1F / negativeExplicitProbabilities.getSum(false)); + + // TODO 此处随机概率落在某个分段(需要重构,否则永远最多落在5个分段,应该是Bug.) 
+ float random = RandomUtility.randomFloat(1F); + negativeExplicitScores[itemIndex] = scoreSize - 1; + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + if ((random -= negativeExplicitProbabilities.getValue(scoreIndex)) <= 0F) { + negativeExplicitScores[itemIndex] = scoreIndex; + break; + } + } + if (isLast) { + negativeExplicitActs[itemIndex][negativeExplicitScores[itemIndex]] += 1F; + } + } + + factorSum.setValues(0F); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + factorSum.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + scalar.setValue(value + weightProbabilities[itemIndex][negativeExplicitScores[itemIndex]][index]); + }); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float probability = (float) (1F / (1F + Math.exp(-factorSum.getValue(factorIndex) - implicitBiasProbabilities[factorIndex]))); + if (probability > RandomUtility.randomFloat(1F)) { + negativeImplicitStates.add(factorIndex); + if (isLast) { + negativeImplicitActs[factorIndex] += 1.0; + } + } + } + + if (!isLast) { + currentImplicitStates = negativeImplicitStates; + } + } while (++step < steps); + + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + int scoreIndex = scoreIndexes.get(term.getValue()); + for (int factorIndex : positiveImplicitStates) { + positiveWeights[itemIndex][scoreIndex][factorIndex] += 1D; + } + for (int factorIndex : negativeImplicitStates) { + negativeWeights[itemIndex][negativeExplicitScores[itemIndex]][factorIndex] += 1D; + } + } + + positiveImplicitStates.clear(); + negativeImplicitStates.clear(); + update(userIndex); + } + } + + } + + private void update(int userIndex) { + // TODO size是否应该由参数指定? 
+ if (((userIndex + 1) % sampleSize) == 0 || (userIndex + 1) == userSize) { + int numCases = userIndex % sampleSize; + numCases++; + + float positiveExplicitAct; + float negativeExplicitAct; + float positiveImplicitAct; + float negativeImplicitAct; + float positiveWeight; + float negativeWeight; + + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + if (itemCount[itemIndex] == 0) { + continue; + } + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + positiveExplicitAct = positiveExplicitActs[itemIndex][scoreIndex]; + negativeExplicitAct = negativeExplicitActs[itemIndex][scoreIndex]; + if (positiveExplicitAct != 0F || negativeExplicitAct != 0F) { + positiveExplicitAct /= itemCount[itemIndex]; + negativeExplicitAct /= itemCount[itemIndex]; + explicitBiasSums[itemIndex][scoreIndex] = momentum * explicitBiasSums[itemIndex][scoreIndex] + epsilonExplicitBias * (positiveExplicitAct - negativeExplicitAct - lamtaBias * explicitBiasProbabilities[itemIndex][scoreIndex]); + explicitBiasProbabilities[itemIndex][scoreIndex] += explicitBiasSums[itemIndex][scoreIndex]; + positiveExplicitActs[itemIndex][scoreIndex] = 0F; + negativeExplicitActs[itemIndex][scoreIndex] = 0F; + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + positiveWeight = positiveWeights[itemIndex][scoreIndex][factorIndex]; + negativeWeight = negativeWeights[itemIndex][scoreIndex][factorIndex]; + if (positiveWeight != 0F || negativeWeight != 0F) { + positiveWeight /= itemCount[itemIndex]; + negativeWeight /= itemCount[itemIndex]; + weightSums[itemIndex][scoreIndex][factorIndex] = momentum * weightSums[itemIndex][scoreIndex][factorIndex] + epsilonWeight * ((positiveWeight - negativeWeight) - lamtaWeight * weightProbabilities[itemIndex][scoreIndex][factorIndex]); + weightProbabilities[itemIndex][scoreIndex][factorIndex] += weightSums[itemIndex][scoreIndex][factorIndex]; + positiveWeights[itemIndex][scoreIndex][factorIndex] = 0F; + negativeWeights[itemIndex][scoreIndex][factorIndex] = 0F; + } + } + } + itemCount[itemIndex] = 0; + } + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + positiveImplicitAct = positiveImplicitActs[factorIndex]; + negativeImplicitAct = negativeImplicitActs[factorIndex]; + if (positiveImplicitAct != 0F || negativeImplicitAct != 0F) { + positiveImplicitAct /= numCases; + negativeImplicitAct /= numCases; + implicitBiasSums[factorIndex] = momentum * implicitBiasSums[factorIndex] + epsilonImplicitBias * (positiveImplicitAct - negativeImplicitAct - lamtaBias * implicitBiasProbabilities[factorIndex]); + implicitBiasProbabilities[factorIndex] += implicitBiasSums[factorIndex]; + positiveImplicitActs[factorIndex] = 0F; + negativeImplicitActs[factorIndex] = 0F; + } + } + } + + } + + private void reset() { + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + itemCount[itemIndex] = 0; + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + positiveExplicitActs[itemIndex][scoreIndex] = 0F; + negativeExplicitActs[itemIndex][scoreIndex] = 0F; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + positiveWeights[itemIndex][scoreIndex][factorIndex] = 0F; + negativeWeights[itemIndex][scoreIndex][factorIndex] = 0F; + } + } + } + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + positiveImplicitActs[factorIndex] = 0F; + negativeImplicitActs[factorIndex] = 0F; + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int 
itemIndex = instance.getQualityFeature(itemDimension); + float[] socreProbabilities = new float[scoreSize]; + float[] factorProbabilities = new float[factorSize]; + float[] factorSums = new float[factorSize]; + // 用户历史分数记录? + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + for (VectorScalar term : userVector) { + int termIndex = term.getIndex(); + int scoreIndex = scoreIndexes.get(term.getValue()); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + factorSums[factorIndex] += weightProbabilities[termIndex][scoreIndex][factorIndex]; + } + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + factorProbabilities[factorIndex] = (float) (1F / (1F + Math.exp(0F - factorSums[factorIndex] - implicitBiasProbabilities[factorIndex]))); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + socreProbabilities[scoreIndex] += factorProbabilities[factorIndex] * weightProbabilities[itemIndex][scoreIndex][factorIndex]; + } + } + float probabilitySum = 0F; + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + socreProbabilities[scoreIndex] = (float) (1F / (1F + Math.exp(0F - socreProbabilities[scoreIndex] - explicitBiasProbabilities[itemIndex][scoreIndex]))); + probabilitySum += socreProbabilities[scoreIndex]; + } + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + socreProbabilities[scoreIndex] /= probabilitySum; + } + float predict = 0F; + switch (predictionType) { + case MAX: + float score = 0F; + float probability = 0F; + for (Entry term : scoreIndexes.entrySet()) { + if (socreProbabilities[term.getValue()] > probability) { + probability = socreProbabilities[term.getValue()]; + score = term.getKey(); + } + } + predict = score; + break; + case MEAN: + float mean = 0f; + for (Entry term : scoreIndexes.entrySet()) { + mean += socreProbabilities[term.getValue()] * term.getKey(); + } + predict = mean; + break; + } + instance.setQuantityMark(predict); + } + + @Override + protected void eStep() { + // TODO Auto-generated method stub + + } + + @Override + protected void mStep() { + // TODO Auto-generated method stub + + } + +} + +enum PredictionType { + MAX, MEAN +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/RFRecModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/RFRecModel.java new file mode 100644 index 0000000..750b9f9 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/RFRecModel.java @@ -0,0 +1,199 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +import it.unimi.dsi.fastutil.floats.Float2IntLinkedOpenHashMap; +import it.unimi.dsi.fastutil.floats.FloatRBTreeSet; +import it.unimi.dsi.fastutil.floats.FloatSet; + +/** + * + * RF Rec推荐器 + * + *
+ * RF-Rec: Fast and Accurate Computation of Recommendations based on Rating Frequencies
+ * Reference: LibRec team
+ * 
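+ * 
+ * The prediction mixes a per-user and a per-item frequency estimate; each candidate
+ * score is weighted by its Laplace-smoothed rating frequency (plus a bonus when it
+ * matches the rounded user or item mean), and the two estimates are combined with
+ * the learned userWeights and itemWeights (see predict).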
+ * + * @author Birdy + * + */ +public class RFRecModel extends MatrixFactorizationModel { + /** + * The average ratings of users + */ + private DenseVector userMeans; + + /** + * The average ratings of items + */ + private DenseVector itemMeans; + + /** 分数索引 (TODO 考虑取消或迁移.本质为连续特征离散化) */ + protected Float2IntLinkedOpenHashMap scoreIndexes; + + /** + * The number of ratings per rating value per user + */ + private DenseMatrix userScoreFrequencies; + + /** + * The number of ratings per rating value per item + */ + private DenseMatrix itemScoreFrequencies; + + /** + * User weights learned by the gradient solver + */ + private DenseVector userWeights; + + /** + * Item weights learned by the gradient solver. + */ + private DenseVector itemWeights; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // Calculate the average ratings + userMeans = DenseVector.valueOf(userSize); + itemMeans = DenseVector.valueOf(itemSize); + userWeights = DenseVector.valueOf(userSize); + itemWeights = DenseVector.valueOf(itemSize); + + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0) { + userMeans.setValue(userIndex, meanScore); + } else { + userMeans.setValue(userIndex, userVector.getSum(false) / userVector.getElementSize()); + } + userWeights.setValue(userIndex, 0.6F + RandomUtility.randomFloat(0.01F)); + } + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + if (itemVector.getElementSize() == 0) { + itemMeans.setValue(itemIndex, meanScore); + } else { + itemMeans.setValue(itemIndex, itemVector.getSum(false) / itemVector.getElementSize()); + } + itemWeights.setValue(itemIndex, 0.4F + RandomUtility.randomFloat(0.01F)); + } + + // TODO 此处会与scoreIndexes一起重构,本质为连续特征离散化. + FloatSet scores = new FloatRBTreeSet(); + for (MatrixScalar term : scoreMatrix) { + scores.add(term.getValue()); + } + scores.remove(0F); + scoreIndexes = new Float2IntLinkedOpenHashMap(); + int index = 0; + for (float score : scores) { + scoreIndexes.put(score, index++); + } + + // Calculate the frequencies. + // Users,items + userScoreFrequencies = DenseMatrix.valueOf(userSize, actionSize); + itemScoreFrequencies = DenseMatrix.valueOf(itemSize, actionSize); + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + int scoreIndex = scoreIndexes.get(term.getValue()); + userScoreFrequencies.shiftValue(userIndex, scoreIndex, 1F); + itemScoreFrequencies.shiftValue(itemIndex, scoreIndex, 1F); + } + } + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float error = term.getValue() - predict(userIndex, itemIndex); + + // Gradient-Step on user weights. + float userWeight = userWeights.getValue(userIndex) + learnRatio * (error - userRegularization * userWeights.getValue(userIndex)); + userWeights.setValue(userIndex, userWeight); + + // Gradient-Step on item weights. 
+ float itemWeight = itemWeights.getValue(itemIndex) + learnRatio * (error - itemRegularization * itemWeights.getValue(itemIndex)); + itemWeights.setValue(itemIndex, itemWeight); + } + } + } + + /** + * Returns 1 if the rating is similar to the rounded average value + * + * @param mean the average + * @param score the rating + * @return 1 when the values are equal + */ + private float isMean(float mean, int score) { + return Math.round(mean) == score ? 1F : 0F; + } + + @Override + protected float predict(int userIndex, int itemIndex) { + float value = meanScore; + float userSum = userScoreFrequencies.getRowVector(userIndex).getSum(false); + float itemSum = itemScoreFrequencies.getRowVector(itemIndex).getSum(false); + float userMean = userMeans.getValue(userIndex); + float itemMean = itemMeans.getValue(itemIndex); + + if (userSum > 0F && itemSum > 0F && userMean > 0F && itemMean > 0F) { + float numeratorUser = 0F; + float denominatorUser = 0F; + float numeratorItem = 0F; + float denominatorItem = 0F; + float frequency = 0F; + // Go through all the possible rating values + for (int scoreIndex = 0, scoreSize = scoreIndexes.size(); scoreIndex < scoreSize; scoreIndex++) { + // user component + frequency = userScoreFrequencies.getValue(userIndex, scoreIndex); + frequency = frequency + 1F + isMean(userMean, scoreIndex); + numeratorUser += frequency * scoreIndex; + denominatorUser += frequency; + + // item component + frequency = itemScoreFrequencies.getValue(itemIndex, scoreIndex); + frequency = frequency + 1F + isMean(itemMean, scoreIndex); + numeratorItem += frequency * scoreIndex; + denominatorItem += frequency; + } + + float userWeight = userWeights.getValue(userIndex); + float itemWeight = itemWeights.getValue(itemIndex); + value = userWeight * numeratorUser / denominatorUser + itemWeight * numeratorItem / denominatorItem; + } else { + // if the user or item weren't known in the training phase... + if (userSum == 0F || userMean == 0F) { + if (itemMean != 0F) { + return itemMean; + } else { + return meanScore; + } + } + if (itemSum == 0F || itemMean == 0F) { + if (userMean != 0F) { + return userMean; + } else { + // Some heuristic -> a bit above the average rating + return meanScore; + } + } + } + return value; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/SVDPlusPlusModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/SVDPlusPlusModel.java new file mode 100644 index 0000000..fe4fa7f --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/SVDPlusPlusModel.java @@ -0,0 +1,136 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; + +/** + * + * SVD++推荐器 + * + *
+ * Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model
+ * Reference: LibRec team
+ * 
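+ * 
+ * Prediction with implicit feedback, as implemented below; informally:
+ * predict(u, i) = meanScore + b[u] + b[i] + Q[i] . (P[u] + (1 / sqrt(|N(u)|)) * sum_{j in N(u)} Y[j])
+ * where Y holds the implicit item factors (factorMatrix) and N(u) the items rated by u.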
+ *
+ * @author Birdy
+ *
+ */
+public class SVDPlusPlusModel extends BiasedMFModel {
+    /**
+     * item implicit feedback factors, "imp" string means implicit
+     */
+    private DenseMatrix factorMatrix;
+
+    /**
+     * implicit item regularization
+     */
+    private float regImpItem;
+
+    /*
+     * (non-Javadoc)
+     *
+     * @see net.librecommender.recommender.AbstractRecommender#setup()
+     */
+    @Override
+    public void prepare(Option configuration, DataModule model, DataSpace space) {
+        super.prepare(configuration, model, space);
+        regImpItem = configuration.getFloat("recommender.impItem.regularization", 0.015F);
+        factorMatrix = DenseMatrix.valueOf(itemSize, factorSize);
+        factorMatrix.iterateElement(MathCalculator.SERIAL, (element) -> {
+            element.setValue(distribution.sample().floatValue());
+        });
+    }
+
+    @Override
+    protected void doPractice() {
+        DenseVector factorVector = DenseVector.valueOf(factorSize);
+        for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
+            totalError = 0F;
+            for (int userIndex = 0; userIndex < userSize; userIndex++) {
+                SparseVector userVector = scoreMatrix.getRowVector(userIndex);
+                if (userVector.getElementSize() == 0) {
+                    continue;
+                }
+                for (VectorScalar outerTerm : userVector) {
+                    int itemIndex = outerTerm.getIndex();
+                    // TODO this could be changed to reset according to userVector
+                    factorVector.setValues(0F);
+                    for (VectorScalar innerTerm : userVector) {
+                        factorVector.addVector(factorMatrix.getRowVector(innerTerm.getIndex()));
+                    }
+                    float scale = (float) Math.sqrt(userVector.getElementSize());
+                    if (scale > 0F) {
+                        factorVector.scaleValues(1F / scale);
+                    }
+                    float error = outerTerm.getValue() - predict(userIndex, itemIndex, factorVector);
+                    totalError += error * error;
+                    // update user and item bias
+                    float userBias = userBiases.getValue(userIndex);
+                    userBiases.shiftValue(userIndex, learnRatio * (error - regBias * userBias));
+                    totalError += regBias * userBias * userBias;
+                    float itemBias = itemBiases.getValue(itemIndex);
+                    itemBiases.shiftValue(itemIndex, learnRatio * (error - regBias * itemBias));
+                    totalError += regBias * itemBias * itemBias;
+
+                    // update user and item factors
+                    for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                        float userFactor = userFactors.getValue(userIndex, factorIndex);
+                        float itemFactor = itemFactors.getValue(itemIndex, factorIndex);
+                        userFactors.shiftValue(userIndex, factorIndex, learnRatio * (error * itemFactor - userRegularization * userFactor));
+                        itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * (error * (userFactor + factorVector.getValue(factorIndex)) - itemRegularization * itemFactor));
+                        totalError += userRegularization * userFactor * userFactor + itemRegularization * itemFactor * itemFactor;
+                        for (VectorScalar innerTerm : userVector) {
+                            int index = innerTerm.getIndex();
+                            float factor = factorMatrix.getValue(index, factorIndex);
+                            factorMatrix.shiftValue(index, factorIndex, learnRatio * (error * itemFactor / scale - regImpItem * factor));
+                            totalError += regImpItem * factor * factor;
+                        }
+                    }
+                }
+            }
+
+            totalError *= 0.5F;
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
+
+    private float predict(int userIndex, int itemIndex, DenseVector factorVector) {
+        float value = userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex) + meanScore;
+        // sum with user factors
+        for (int index = 0; index < factorSize; index++) {
+            value = value + (factorVector.getValue(index) + userFactors.getValue(userIndex, index)) * itemFactors.getValue(itemIndex,
index); + } + return value; + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + // TODO 此处需要重构,取消DenseVector. + DenseVector factorVector = DenseVector.valueOf(factorSize); + // sum of implicit feedback factors of userIdx with weight Math.sqrt(1.0 + // / userItemsList.get(userIdx).size()) + for (VectorScalar term : userVector) { + factorVector.addVector(factorMatrix.getRowVector(term.getIndex())); + } + float scale = (float) Math.sqrt(userVector.getElementSize()); + if (scale > 0F) { + factorVector.scaleValues(1F / scale); + } + instance.setQuantityMark(predict(userIndex, itemIndex, factorVector)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/URPModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/URPModel.java new file mode 100644 index 0000000..e456909 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/URPModel.java @@ -0,0 +1,362 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Map.Entry; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.ProbabilisticGraphicalModel; +import com.jstarcraft.rns.utility.GammaUtility; +import com.jstarcraft.rns.utility.SampleUtility; + +import it.unimi.dsi.fastutil.ints.Int2IntRBTreeMap; +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * + * URP推荐器 + * + *
+ * User Rating Profile: an LDA model for rating prediction
+ * Reference: LibRec team
+ * 
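+ * 
+ * Collapsed Gibbs sampling over one topic assignment per (user, item) pair; a topic
+ * k is redrawn with probability proportional to
+ * (n(u, k) + alpha[k]) / (n(u) + sum(alpha)) * (n(k, i, r) + beta[r]) / (n(k, i) + sum(beta))
+ * as in eStep, with alpha and beta re-estimated by Minka's fixed-point update in mStep.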
+ * + * @author Birdy + * + */ +public class URPModel extends ProbabilisticGraphicalModel { + + private float preRMSE; + + /** + * number of occurrentces of entry (user, topic) + */ + private DenseMatrix userTopicTimes; + + /** + * number of occurences of users + */ + private DenseVector userTopicNumbers; + + /** + * number of occurrences of entry (topic, item) + */ + private DenseMatrix topicItemNumbers; + + /** + * P(k | u) + */ + private DenseMatrix userTopicProbabilities, userTopicSums; + + /** + * user parameters + */ + private DenseVector alpha; + + /** + * item parameters + */ + private DenseVector beta; + + /** + * + */ + private Int2IntRBTreeMap topicAssignments; + + /** + * number of occurrences of entry (t, i, r) + */ + private int[][][] topicItemTimes; // Nkir + + /** + * cumulative statistics of probabilities of (t, i, r) + */ + private float[][][] topicItemScoreSums; // PkirSum; + + /** + * posterior probabilities of parameters phi_{k, i, r} + */ + private float[][][] topicItemScoreProbabilities; // Pkir; + + private DenseVector randomProbabilities; + + /** 学习矩阵与校验矩阵(TODO 将scoreMatrix划分) */ + private SparseMatrix learnMatrix, checkMatrix; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + float checkRatio = configuration.getFloat("recommender.urp.chech.ratio", 0F); + if (checkRatio == 0F) { + learnMatrix = scoreMatrix; + checkMatrix = null; + } else { + HashMatrix learnTable = new HashMatrix(true, userSize, itemSize, new Long2FloatRBTreeMap()); + HashMatrix checkTable = new HashMatrix(true, userSize, itemSize, new Long2FloatRBTreeMap()); + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + if (RandomUtility.randomFloat(1F) <= checkRatio) { + checkTable.setValue(userIndex, itemIndex, score); + } else { + learnTable.setValue(userIndex, itemIndex, score); + } + } + learnMatrix = SparseMatrix.valueOf(userSize, itemSize, learnTable); + checkMatrix = SparseMatrix.valueOf(userSize, itemSize, checkTable); + } + + // cumulative parameters + userTopicSums = DenseMatrix.valueOf(userSize, factorSize); + topicItemScoreSums = new float[factorSize][itemSize][scoreSize]; + + // initialize count variables + userTopicTimes = DenseMatrix.valueOf(userSize, factorSize); + userTopicNumbers = DenseVector.valueOf(userSize); + + topicItemTimes = new int[factorSize][itemSize][scoreSize]; + topicItemNumbers = DenseMatrix.valueOf(factorSize, itemSize); + + float initAlpha = configuration.getFloat("recommender.pgm.bucm.alpha", 1F / factorSize); + alpha = DenseVector.valueOf(factorSize); + alpha.setValues(initAlpha); + + float initBeta = configuration.getFloat("recommender.pgm.bucm.beta", 1F / factorSize); + beta = DenseVector.valueOf(scoreSize); + beta.setValues(initBeta); + + // initialize topics + topicAssignments = new Int2IntRBTreeMap(); + for (MatrixScalar term : learnMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + int scoreIndex = scoreIndexes.get(score); // rating level 0 ~ + // numLevels + int topicIndex = RandomUtility.randomInteger(factorSize); // 0 + // ~ + // k-1 + + // Assign a topic t to pair (u, i) + topicAssignments.put(userIndex * itemSize + itemIndex, topicIndex); + // number of pairs (u, t) in (u, i, t) + userTopicTimes.shiftValue(userIndex, topicIndex, 1); + // total number of items of user u + userTopicNumbers.shiftValue(userIndex, 1); 
+ + // number of pairs (t, i, r) + topicItemTimes[topicIndex][itemIndex][scoreIndex]++; + // total number of words assigned to topic t + topicItemNumbers.shiftValue(topicIndex, itemIndex, 1); + } + + randomProbabilities = DenseVector.valueOf(factorSize); + } + + @Override + protected void eStep() { + float sumAlpha = alpha.getSum(false); + float sumBeta = beta.getSum(false); + + // collapse Gibbs sampling + for (MatrixScalar term : learnMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + int scoreIndex = scoreIndexes.get(score); // rating level 0 ~ + // numLevels + int assignmentIndex = topicAssignments.get(userIndex * itemSize + itemIndex); + + userTopicTimes.shiftValue(userIndex, assignmentIndex, -1); + userTopicNumbers.shiftValue(userIndex, -1); + topicItemTimes[assignmentIndex][itemIndex][scoreIndex]--; + topicItemNumbers.shiftValue(assignmentIndex, itemIndex, -1); + + // 计算概率 + DefaultScalar sum = DefaultScalar.getInstance(); + sum.setValue(0F); + randomProbabilities.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = (userTopicTimes.getValue(userIndex, index) + alpha.getValue(index)) / (userTopicNumbers.getValue(userIndex) + sumAlpha) * (topicItemTimes[index][itemIndex][scoreIndex] + beta.getValue(scoreIndex)) / (topicItemNumbers.getValue(index, itemIndex) + sumBeta); + sum.shiftValue(value); + scalar.setValue(sum.getValue()); + }); + assignmentIndex = SampleUtility.binarySearch(randomProbabilities, 0, randomProbabilities.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue())); + + // new topic t + topicAssignments.put(userIndex * itemSize + itemIndex, assignmentIndex); + + // add newly estimated z_i to count variables + userTopicTimes.shiftValue(userIndex, assignmentIndex, 1); + userTopicNumbers.shiftValue(userIndex, 1); + topicItemTimes[assignmentIndex][itemIndex][scoreIndex]++; + topicItemNumbers.shiftValue(assignmentIndex, itemIndex, 1); + } + + } + + /** + * Thomas P. 
Minka, Estimating a Dirichlet distribution, see Eq.(55) + */ + @Override + protected void mStep() { + float denominator; + float value; + + // update alpha vector + float alphaSum = alpha.getSum(false); + float alphaDigamma = GammaUtility.digamma(alphaSum); + float alphaValue; + denominator = 0F; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + value = userTopicNumbers.getValue(userIndex); + if (value != 0F) { + denominator += GammaUtility.digamma(value + alphaSum) - alphaDigamma; + } + } + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + alphaValue = alpha.getValue(topicIndex); + alphaDigamma = GammaUtility.digamma(alphaValue); + float numerator = 0F; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + value = userTopicTimes.getValue(userIndex, topicIndex); + if (value != 0F) { + numerator += GammaUtility.digamma(value + alphaValue) - alphaDigamma; + } + } + if (numerator != 0F) { + alpha.setValue(topicIndex, alphaValue * (numerator / denominator)); + } + } + + // update beta_k + float betaSum = beta.getSum(false); + float betaDigamma = GammaUtility.digamma(betaSum); + float betaValue; + denominator = 0F; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + value = topicItemNumbers.getValue(topicIndex, itemIndex); + if (value != 0F) { + denominator += GammaUtility.digamma(value + betaSum) - betaDigamma; + } + } + } + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + betaValue = beta.getValue(scoreIndex); + betaDigamma = GammaUtility.digamma(betaValue); + float numerator = 0F; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + value = topicItemTimes[topicIndex][itemIndex][scoreIndex]; + if (value != 0F) { + numerator += GammaUtility.digamma(value + betaValue) - betaDigamma; + } + } + } + if (numerator != 0F) { + beta.setValue(scoreIndex, betaValue * (numerator / denominator)); + } + } + } + + protected void readoutParameters() { + float value = 0F; + float sumAlpha = alpha.getSum(false); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + value = (userTopicTimes.getValue(userIndex, topicIndex) + alpha.getValue(topicIndex)) / (userTopicNumbers.getValue(userIndex) + sumAlpha); + userTopicSums.shiftValue(userIndex, topicIndex, value); + } + } + float sumBeta = beta.getSum(false); + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + value = (topicItemTimes[topicIndex][itemIndex][scoreIndex] + beta.getValue(scoreIndex)) / (topicItemNumbers.getValue(topicIndex, itemIndex) + sumBeta); + topicItemScoreSums[topicIndex][itemIndex][scoreIndex] += value; + } + } + } + numberOfStatistics++; + } + + @Override + protected void estimateParameters() { + userTopicProbabilities = DenseMatrix.copyOf(userTopicSums); + userTopicProbabilities.scaleValues(1F / numberOfStatistics); + topicItemScoreProbabilities = new float[factorSize][itemSize][scoreSize]; + for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) { + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + topicItemScoreProbabilities[topicIndex][itemIndex][scoreIndex] = 
topicItemScoreSums[topicIndex][itemIndex][scoreIndex] / numberOfStatistics;
+                }
+            }
+        }
+    }
+
+    @Override
+    protected boolean isConverged(int iter) {
+        // TODO Using validMatrix here would seem more reasonable than checkMatrix.
+        if (checkMatrix == null) {
+            return false;
+        }
+        // get the posterior probability distribution first
+        estimateParameters();
+        // compute the current RMSE
+        int count = 0;
+        float sum = 0F;
+        for (MatrixScalar term : checkMatrix) {
+            int userIndex = term.getRow();
+            int itemIndex = term.getColumn();
+            float score = term.getValue();
+            float predict = predict(userIndex, itemIndex);
+            if (Float.isNaN(predict)) {
+                continue;
+            }
+            float error = score - predict;
+            sum += error * error;
+            count++;
+        }
+        float rmse = (float) Math.sqrt(sum / count);
+        float delta = rmse - preRMSE;
+        if (numberOfStatistics > 1 && delta > 0F) {
+            return true;
+        }
+        preRMSE = rmse;
+        return false;
+    }
+
+    private float predict(int userIndex, int itemIndex) {
+        float value = 0F;
+        for (Entry<Float, Integer> term : scoreIndexes.entrySet()) {
+            float score = term.getKey();
+            float probability = 0F;
+            for (int topicIndex = 0; topicIndex < factorSize; topicIndex++) {
+                probability += userTopicProbabilities.getValue(userIndex, topicIndex) * topicItemScoreProbabilities[topicIndex][itemIndex][term.getValue()];
+            }
+            value += probability * score;
+        }
+        return value;
+    }
+
+    @Override
+    public void predict(DataInstance instance) {
+        int userIndex = instance.getQualityFeature(userDimension);
+        int itemIndex = instance.getQualityFeature(itemDimension);
+        instance.setQuantityMark(predict(userIndex, itemIndex));
+    }
+
+}
diff --git a/src/main/java/com/jstarcraft/rns/model/collaborative/rating/UserKNNRatingModel.java b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/UserKNNRatingModel.java
new file mode 100644
index 0000000..3637715
--- /dev/null
+++ b/src/main/java/com/jstarcraft/rns/model/collaborative/rating/UserKNNRatingModel.java
@@ -0,0 +1,79 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import java.util.Iterator;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.math.structure.vector.MathVector;
+import com.jstarcraft.ai.math.structure.vector.SparseVector;
+import com.jstarcraft.ai.math.structure.vector.VectorScalar;
+import com.jstarcraft.rns.model.collaborative.UserKNNModel;
+
+/**
+ *
+ * User KNN recommender
+ *
+ * <pre>
+ * Reference: LibRec team<br>
+ * </pre>
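+ *
+ * Prediction sketch (inferred from the merge loop in predict below, not a
+ * normative formula): the estimate is the user's mean score plus the
+ * similarity-weighted deviations of the neighbors' scores from their own
+ * means, falling back to the global mean when no neighbor rated the item:
+ *
+ * <pre>
+ * predict(u, i) = mean(u) + sum_{v in N(u)} w(u, v) * (r(v, i) - mean(v)) / sum_{v in N(u)} |w(u, v)|
+ * </pre>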
+ * + * @author Birdy + * + */ +public class UserKNNRatingModel extends UserKNNModel { + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + SparseVector itemVector = itemVectors[itemIndex]; + MathVector neighbors = userNeighbors[userIndex]; + if (itemVector.getElementSize() == 0 || neighbors.getElementSize() == 0) { + instance.setQuantityMark(meanScore); + return; + } + + float sum = 0F, absolute = 0F; + int count = 0; + int leftCursor = 0, rightCursor = 0, leftSize = itemVector.getElementSize(), rightSize = neighbors.getElementSize(); + Iterator leftIterator = itemVector.iterator(); + VectorScalar leftTerm = leftIterator.next(); + Iterator rightIterator = neighbors.iterator(); + VectorScalar rightTerm = rightIterator.next(); + // 判断两个有序数组中是否存在相同的数字 + while (leftCursor < leftSize && rightCursor < rightSize) { + if (leftTerm.getIndex() == rightTerm.getIndex()) { + count++; + float correlation = rightTerm.getValue(); + float score = leftTerm.getValue(); + sum += correlation * (score - userMeans.getValue(rightTerm.getIndex())); + absolute += Math.abs(correlation); + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + leftCursor++; + rightCursor++; + } else if (leftTerm.getIndex() > rightTerm.getIndex()) { + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + rightCursor++; + } else if (leftTerm.getIndex() < rightTerm.getIndex()) { + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + leftCursor++; + } + } + + if (count == 0) { + instance.setQuantityMark(meanScore); + return; + } + + instance.setQuantityMark(absolute > 0 ? userMeans.getValue(userIndex) + sum / absolute : meanScore); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/content/EFMModel.java b/src/main/java/com/jstarcraft/rns/model/content/EFMModel.java new file mode 100644 index 0000000..8c2d250 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/EFMModel.java @@ -0,0 +1,361 @@ +package com.jstarcraft.rns.model.content; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.attribute.MemoryQualityAttribute; +import com.jstarcraft.ai.math.MathUtility; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.ArrayVector; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.core.utility.StringUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * + * EFM推荐器 + * + *
+ * <pre>
+ * Explicit factor models for explainable recommendation based on phrase-level sentiment analysis<br>
+ * Reference: LibRec team<br>
+ * </pre>
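+ *
+ * Objective sketch (an assumption inferred from the multiplicative updates in
+ * doPractice, following the EFM paper; A is the score matrix, X the
+ * user-feature attention matrix and Y the item-feature quality matrix built
+ * in prepare):
+ *
+ * <pre>
+ * minimize |A - (U1 * U2' + H1 * H2')|^2 + lambdax * |X - U1 * V'|^2 + lambday * |Y - U2 * V'|^2 + norm regularizers
+ * </pre>
+ *
+ * Non-negativity is preserved by updates of the form factor *= sqrt(numerator / denominator).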
+ * + * @author Birdy + * + */ +public abstract class EFMModel extends MatrixFactorizationModel { + + protected String commentField; + protected int commentDimension; + protected int numberOfFeatures; + protected int numberOfExplicitFeatures; + protected int numberOfImplicitFeatures; + protected float scoreScale; + protected DenseMatrix featureFactors; + protected DenseMatrix userExplicitFactors; + protected DenseMatrix userImplicitFactors; + protected DenseMatrix itemExplicitFactors; + protected DenseMatrix itemImplicitFactors; + protected SparseMatrix userFeatures; + protected SparseMatrix itemFeatures; + protected float attentionRegularization; + protected float qualityRegularization; + protected float explicitRegularization; + protected float implicitRegularization; + protected float featureRegularization; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + commentField = configuration.getString("data.model.fields.comment"); + commentDimension = model.getQualityInner(commentField); + MemoryQualityAttribute attribute = (MemoryQualityAttribute) space.getQualityAttribute(commentField); + Object[] wordValues = attribute.getDatas(); + + scoreScale = maximumScore - minimumScore; + numberOfExplicitFeatures = configuration.getInteger("recommender.factor.explicit", 5); + numberOfImplicitFeatures = factorSize - numberOfExplicitFeatures; + attentionRegularization = configuration.getFloat("recommender.regularization.lambdax", 0.001F); + qualityRegularization = configuration.getFloat("recommender.regularization.lambday", 0.001F); + explicitRegularization = configuration.getFloat("recommender.regularization.lambdau", 0.001F); + implicitRegularization = configuration.getFloat("recommender.regularization.lambdah", 0.001F); + featureRegularization = configuration.getFloat("recommender.regularization.lambdav", 0.001F); + + Map featureDictionaries = new HashMap<>(); + Map userDictionaries = new HashMap<>(); + Map itemDictionaries = new HashMap<>(); + + numberOfFeatures = 0; + // // TODO 此处保证所有特征都会被识别 + // for (Object value : wordValues) { + // String wordValue = (String) value; + // String[] words = wordValue.split(" "); + // for (String word : words) { + // // TODO 此处似乎是Bug,不应该再将word划分为更细粒度. + // String feature = word.split(":")[0]; + // if (!featureDictionaries.containsKey(feature) && + // StringUtils.isNotEmpty(feature)) { + // featureDictionaries.put(feature, numberOfWords); + // numberOfWords++; + // } + // } + // } + + for (DataInstance sample : model) { + int userIndex = sample.getQualityFeature(userDimension); + int itemIndex = sample.getQualityFeature(itemDimension); + int wordIndex = sample.getQualityFeature(commentDimension); + String wordValue = (String) wordValues[wordIndex]; + String[] words = wordValue.split(" "); + StringBuilder buffer; + for (String word : words) { + // TODO 此处似乎是Bug,不应该再将word划分为更细粒度. 
+ String feature = word.split(":")[0]; + if (!featureDictionaries.containsKey(feature) && !StringUtility.isEmpty(feature)) { + featureDictionaries.put(feature, numberOfFeatures++); + } + buffer = userDictionaries.get(userIndex); + if (buffer != null) { + buffer.append(" ").append(word); + } else { + userDictionaries.put(userIndex, new StringBuilder(word)); + } + buffer = itemDictionaries.get(itemIndex); + if (buffer != null) { + buffer.append(" ").append(word); + } else { + itemDictionaries.put(itemIndex, new StringBuilder(word)); + } + } + } + + // Create V,U1,H1,U2,H2 + featureFactors = DenseMatrix.valueOf(numberOfFeatures, numberOfExplicitFeatures); + featureFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(0.01F)); + }); + userExplicitFactors = DenseMatrix.valueOf(userSize, numberOfExplicitFeatures); + userExplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + userImplicitFactors = DenseMatrix.valueOf(userSize, numberOfImplicitFeatures); + userImplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemExplicitFactors = DenseMatrix.valueOf(itemSize, numberOfExplicitFeatures); + itemExplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemImplicitFactors = DenseMatrix.valueOf(itemSize, numberOfImplicitFeatures); + itemImplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + + float[] featureValues = new float[numberOfFeatures]; + + // compute UserFeatureAttention + HashMatrix userTable = new HashMatrix(true, userSize, numberOfFeatures, new Long2FloatRBTreeMap()); + for (Entry term : userDictionaries.entrySet()) { + int userIndex = term.getKey(); + String[] words = term.getValue().toString().split(" "); + for (String word : words) { + if (!StringUtility.isEmpty(word)) { + int featureIndex = featureDictionaries.get(word.split(":")[0]); + featureValues[featureIndex] += 1F; + } + } + for (int featureIndex = 0; featureIndex < numberOfFeatures; featureIndex++) { + if (featureValues[featureIndex] != 0F) { + float value = (float) (1F + (scoreScale - 1F) * (2F / (1F + Math.exp(-featureValues[featureIndex])) - 1F)); + userTable.setValue(userIndex, featureIndex, value); + featureValues[featureIndex] = 0F; + } + } + } + userFeatures = SparseMatrix.valueOf(userSize, numberOfFeatures, userTable); + // compute ItemFeatureQuality + HashMatrix itemTable = new HashMatrix(true, itemSize, numberOfFeatures, new Long2FloatRBTreeMap()); + for (Entry term : itemDictionaries.entrySet()) { + int itemIndex = term.getKey(); + String[] words = term.getValue().toString().split(" "); + for (String word : words) { + if (!StringUtility.isEmpty(word)) { + int featureIndex = featureDictionaries.get(word.split(":")[0]); + featureValues[featureIndex] += Double.parseDouble(word.split(":")[1]); + } + } + for (int featureIndex = 0; featureIndex < numberOfFeatures; featureIndex++) { + if (featureValues[featureIndex] != 0F) { + float value = (float) (1F + (scoreScale - 1F) / (1F + Math.exp(-featureValues[featureIndex]))); + itemTable.setValue(itemIndex, featureIndex, value); + featureValues[featureIndex] = 0F; + } + } + } + itemFeatures = SparseMatrix.valueOf(itemSize, numberOfFeatures, itemTable); + + logger.info("numUsers:" + userSize); + logger.info("numItems:" + itemSize); + 
logger.info("numFeatures:" + numberOfFeatures); + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + for (int featureIndex = 0; featureIndex < numberOfFeatures; featureIndex++) { + if (userFeatures.getColumnScope(featureIndex) > 0 && itemFeatures.getColumnScope(featureIndex) > 0) { + SparseVector userVector = userFeatures.getColumnVector(featureIndex); + SparseVector itemVector = itemFeatures.getColumnVector(featureIndex); + // TODO 此处需要重构,应该避免不断构建SparseVector. + int feature = featureIndex; + ArrayVector userFactors = new ArrayVector(userVector); + userFactors.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predictUserFactor(scalar, element.getIndex(), feature)); + }); + ArrayVector itemFactors = new ArrayVector(itemVector); + itemFactors.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predictItemFactor(scalar, element.getIndex(), feature)); + }); + for (int factorIndex = 0; factorIndex < numberOfExplicitFeatures; factorIndex++) { + DenseVector factorUsersVector = userExplicitFactors.getColumnVector(factorIndex); + DenseVector factorItemsVector = itemExplicitFactors.getColumnVector(factorIndex); + float numerator = attentionRegularization * scalar.dotProduct(factorUsersVector, userVector).getValue() + qualityRegularization * scalar.dotProduct(factorItemsVector, itemVector).getValue(); + float denominator = attentionRegularization * scalar.dotProduct(factorUsersVector, userFactors).getValue() + qualityRegularization * scalar.dotProduct(factorItemsVector, itemFactors).getValue() + featureRegularization * featureFactors.getValue(featureIndex, factorIndex) + MathUtility.EPSILON; + featureFactors.setValue(featureIndex, factorIndex, (float) (featureFactors.getValue(featureIndex, factorIndex) * Math.sqrt(numerator / denominator))); + } + } + } + + // Update UserFeatureMatrix by fixing the others + for (int userIndex = 0; userIndex < userSize; userIndex++) { + if (scoreMatrix.getRowScope(userIndex) > 0 && userFeatures.getRowScope(userIndex) > 0) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + SparseVector attentionVector = userFeatures.getRowVector(userIndex); + // TODO 此处需要重构,应该避免不断构建SparseVector. 
+ int user = userIndex; + ArrayVector itemPredictsVector = new ArrayVector(userVector); + itemPredictsVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predict(user, element.getIndex())); + }); + ArrayVector attentionPredVector = new ArrayVector(attentionVector); + attentionPredVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predictUserFactor(scalar, user, element.getIndex())); + }); + for (int factorIndex = 0; factorIndex < numberOfExplicitFeatures; factorIndex++) { + DenseVector factorItemsVector = itemExplicitFactors.getColumnVector(factorIndex); + DenseVector featureVector = featureFactors.getColumnVector(factorIndex); + float numerator = scalar.dotProduct(factorItemsVector, userVector).getValue() + attentionRegularization * scalar.dotProduct(featureVector, attentionVector).getValue(); + float denominator = scalar.dotProduct(factorItemsVector, itemPredictsVector).getValue() + attentionRegularization * scalar.dotProduct(featureVector, attentionPredVector).getValue() + explicitRegularization * userExplicitFactors.getValue(userIndex, factorIndex) + MathUtility.EPSILON; + userExplicitFactors.setValue(userIndex, factorIndex, (float) (userExplicitFactors.getValue(userIndex, factorIndex) * Math.sqrt(numerator / denominator))); + } + } + } + + // Update ItemFeatureMatrix by fixing the others + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + if (scoreMatrix.getColumnScope(itemIndex) > 0 && itemFeatures.getRowScope(itemIndex) > 0) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + SparseVector qualityVector = itemFeatures.getRowVector(itemIndex); + // TODO 此处需要重构,应该避免不断构建SparseVector. + int item = itemIndex; + ArrayVector userPredictsVector = new ArrayVector(itemVector); + userPredictsVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predict(element.getIndex(), item)); + }); + ArrayVector qualityPredVector = new ArrayVector(qualityVector); + qualityPredVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predictItemFactor(scalar, item, element.getIndex())); + }); + for (int factorIndex = 0; factorIndex < numberOfExplicitFeatures; factorIndex++) { + DenseVector factorUsersVector = userExplicitFactors.getColumnVector(factorIndex); + DenseVector featureVector = featureFactors.getColumnVector(factorIndex); + float numerator = scalar.dotProduct(factorUsersVector, itemVector).getValue() + qualityRegularization * scalar.dotProduct(featureVector, qualityVector).getValue(); + float denominator = scalar.dotProduct(factorUsersVector, userPredictsVector).getValue() + qualityRegularization * scalar.dotProduct(featureVector, qualityPredVector).getValue() + explicitRegularization * itemExplicitFactors.getValue(itemIndex, factorIndex) + MathUtility.EPSILON; + itemExplicitFactors.setValue(itemIndex, factorIndex, (float) (itemExplicitFactors.getValue(itemIndex, factorIndex) * Math.sqrt(numerator / denominator))); + } + } + } + + // Update UserHiddenMatrix by fixing the others + for (int userIndex = 0; userIndex < userSize; userIndex++) { + if (scoreMatrix.getRowScope(userIndex) > 0) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + // TODO 此处需要重构,应该避免不断构建SparseVector. 
+ int user = userIndex; + ArrayVector itemPredictsVector = new ArrayVector(userVector); + itemPredictsVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predict(user, element.getIndex())); + }); + for (int factorIndex = 0; factorIndex < numberOfImplicitFeatures; factorIndex++) { + DenseVector hiddenItemsVector = itemImplicitFactors.getColumnVector(factorIndex); + float numerator = scalar.dotProduct(hiddenItemsVector, userVector).getValue(); + float denominator = scalar.dotProduct(hiddenItemsVector, itemPredictsVector).getValue() + implicitRegularization * userImplicitFactors.getValue(userIndex, factorIndex) + MathUtility.EPSILON; + userImplicitFactors.setValue(userIndex, factorIndex, (float) (userImplicitFactors.getValue(userIndex, factorIndex) * Math.sqrt(numerator / denominator))); + } + } + } + + // Update ItemHiddenMatrix by fixing the others + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + if (scoreMatrix.getColumnScope(itemIndex) > 0) { + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + // TODO 此处需要重构,应该避免不断构建SparseVector. + int item = itemIndex; + ArrayVector userPredictsVector = new ArrayVector(itemVector); + userPredictsVector.iterateElement(MathCalculator.SERIAL, (element) -> { + element.setValue(predict(element.getIndex(), item)); + }); + for (int factorIndex = 0; factorIndex < numberOfImplicitFeatures; factorIndex++) { + DenseVector hiddenUsersVector = userImplicitFactors.getColumnVector(factorIndex); + float numerator = scalar.dotProduct(hiddenUsersVector, itemVector).getValue(); + float denominator = scalar.dotProduct(hiddenUsersVector, userPredictsVector).getValue() + implicitRegularization * itemImplicitFactors.getValue(itemIndex, factorIndex) + MathUtility.EPSILON; + itemImplicitFactors.setValue(itemIndex, factorIndex, (float) (itemImplicitFactors.getValue(itemIndex, factorIndex) * Math.sqrt(numerator / denominator))); + } + } + } + + // Compute loss value + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + double rating = term.getValue(); + double predRating = scalar.dotProduct(userExplicitFactors.getRowVector(userIndex), itemExplicitFactors.getRowVector(itemIndex)).getValue() + scalar.dotProduct(userImplicitFactors.getRowVector(userIndex), itemImplicitFactors.getRowVector(itemIndex)).getValue(); + totalError += (rating - predRating) * (rating - predRating); + } + + for (MatrixScalar term : userFeatures) { + int userIndex = term.getRow(); + int featureIndex = term.getColumn(); + double real = term.getValue(); + double pred = predictUserFactor(scalar, userIndex, featureIndex); + totalError += (real - pred) * (real - pred); + } + + for (MatrixScalar term : itemFeatures) { + int itemIndex = term.getRow(); + int featureIndex = term.getColumn(); + double real = term.getValue(); + double pred = predictItemFactor(scalar, itemIndex, featureIndex); + totalError += (real - pred) * (real - pred); + } + + totalError += explicitRegularization * (userExplicitFactors.getNorm(2F, false) + itemExplicitFactors.getNorm(2F, false)); + totalError += implicitRegularization * (userImplicitFactors.getNorm(2F, false) + itemImplicitFactors.getNorm(2F, false)); + totalError += featureRegularization * featureFactors.getNorm(2F, false); + + logger.info("iter:" + epocheIndex + ", loss:" + totalError); + } + } + + protected float predictUserFactor(DefaultScalar scalar, int userIndex, int featureIndex) { + return scalar.dotProduct(userExplicitFactors.getRowVector(userIndex), 
featureFactors.getRowVector(featureIndex)).getValue(); + } + + protected float predictItemFactor(DefaultScalar scalar, int itemIndex, int featureIndex) { + return scalar.dotProduct(itemExplicitFactors.getRowVector(itemIndex), featureFactors.getRowVector(featureIndex)).getValue(); + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + return scalar.dotProduct(userExplicitFactors.getRowVector(userIndex), itemExplicitFactors.getRowVector(itemIndex)).getValue() + scalar.dotProduct(userImplicitFactors.getRowVector(userIndex), itemImplicitFactors.getRowVector(itemIndex)).getValue(); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/content/content.txt b/src/main/java/com/jstarcraft/rns/model/content/content.txt new file mode 100644 index 0000000..928e044 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/content.txt @@ -0,0 +1,8 @@ +基于内容的推荐(Content-based Recommendations): +http://breezedeus.github.io/2012/04/10/breezedeus-content-based-rec.html + +Science Concierge: A Fast Content-Based Recommendation System for Scientific Publications: +https://journals.plos.org/plosone/article?id=10.1371%2Fjournal.pone.0158423 + +a Python repository for content-based recommendation based on Latent semantic analysis (LSA) topic distance and Rocchio Algorithm, see the implementation interactively on http://www.scholarfy.net: +https://github.com/titipata/science_concierge diff --git a/src/main/java/com/jstarcraft/rns/model/content/ranking/EFMRankingModel.java b/src/main/java/com/jstarcraft/rns/model/content/ranking/EFMRankingModel.java new file mode 100644 index 0000000..697cbae --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/ranking/EFMRankingModel.java @@ -0,0 +1,71 @@ +package com.jstarcraft.rns.model.content.ranking; + +import java.util.Arrays; +import java.util.Comparator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.content.EFMModel; + +/** + * + * EFM推荐器 + * + *
+ * <pre>
+ * Explicit factor models for explainable recommendation based on phrase-level sentiment analysis<br>
+ * Reference: LibRec team<br>
+ * </pre>
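+ *
+ * Ranking sketch (as implemented in predict below): the featureLimit features
+ * with the highest predicted user attention are matched against the item's
+ * predicted feature quality, then blended with the plain score estimate:
+ *
+ * <pre>
+ * rank(u, i) = threshold * sum_{top-k features c} X^(u, c) * Y^(i, c) / (featureLimit * maximumScore) + (1 - threshold) * predict(u, i)
+ * </pre>
+ *
+ * where X^ = U1 * V' and Y^ = U2 * V' are the reconstructed attention and quality scores.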
+ * + * @author Birdy + * + */ +public class EFMRankingModel extends EFMModel { + + private float threshold; + + private int featureLimit; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + threshold = configuration.getFloat("efmranking.threshold", 1F); + featureLimit = configuration.getInteger("efmranking.featureLimit", 250); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + + // TODO 此处可以优化性能 + Integer[] orderIndexes = new Integer[numberOfFeatures]; + for (int featureIndex = 0; featureIndex < numberOfFeatures; featureIndex++) { + orderIndexes[featureIndex] = featureIndex; + } + MathVector vector = DenseVector.valueOf(numberOfFeatures); + vector.dotProduct(userExplicitFactors.getRowVector(userIndex), featureFactors, true, MathCalculator.SERIAL); + Arrays.sort(orderIndexes, new Comparator() { + @Override + public int compare(Integer leftIndex, Integer rightIndex) { + return (vector.getValue(leftIndex) > vector.getValue(rightIndex) ? -1 : (vector.getValue(leftIndex) < vector.getValue(rightIndex) ? 1 : 0)); + } + }); + + float value = 0F; + for (int index = 0; index < featureLimit; index++) { + int featureIndex = orderIndexes[index]; + value += predictUserFactor(scalar, userIndex, featureIndex) * predictItemFactor(scalar, itemIndex, featureIndex); + } + value = threshold * (value / (featureLimit * maximumScore)); + value = value + (1F - threshold) * predict(userIndex, itemIndex); + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/content/ranking/TFIDFModel.java b/src/main/java/com/jstarcraft/rns/model/content/ranking/TFIDFModel.java new file mode 100644 index 0000000..cd1e63f --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/ranking/TFIDFModel.java @@ -0,0 +1,187 @@ +package com.jstarcraft.rns.model.content.ranking; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.algorithm.text.AbstractTermFrequency; +import com.jstarcraft.ai.math.algorithm.text.InverseDocumentFrequency; +import com.jstarcraft.ai.math.algorithm.text.NaturalInverseDocumentFrequency; +import com.jstarcraft.ai.math.algorithm.text.TermFrequency; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.ArrayVector; +import com.jstarcraft.ai.math.structure.vector.HashVector; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.Integer2FloatKeyValue; +import com.jstarcraft.core.utility.Neighborhood; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +import it.unimi.dsi.fastutil.ints.Int2FloatAVLTreeMap; +import it.unimi.dsi.fastutil.ints.Int2FloatSortedMap; +import it.unimi.dsi.fastutil.longs.Long2FloatAVLTreeMap; +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * + * TF-IDF推荐器 + * + * @author Birdy + * + */ +public class TFIDFModel extends 
MatrixFactorizationModel { + + private Comparator comparator = new Comparator() { + + @Override + public int compare(Integer2FloatKeyValue left, Integer2FloatKeyValue right) { + int compare = -(Float.compare(left.getValue(), right.getValue())); + if (compare == 0) { + compare = Integer.compare(left.getKey(), right.getKey()); + } + return compare; + } + + }; + + protected String commentField; + protected int commentDimension; + + protected ArrayVector[] userVectors; + protected SparseMatrix itemVectors; + +// protected MathCorrelation correlation; + + private class VectorTermFrequency extends AbstractTermFrequency { + + public VectorTermFrequency(MathVector vector) { + super(new Int2FloatAVLTreeMap(), vector.getElementSize()); + + for (VectorScalar scalar : vector) { + keyValues.put(scalar.getIndex(), scalar.getValue()); + } + } + + } + + private class DocumentIterator implements Iterator { + + private int index = 0; + + @Override + public boolean hasNext() { + return index < itemVectors.getRowSize(); + } + + @Override + public TermFrequency next() { + MathVector vector = itemVectors.getRowVector(index++); + VectorTermFrequency termFrequency = new VectorTermFrequency(vector); + return termFrequency; + } + + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + int numberOfFeatures = 4096; + + // 特征矩阵 + HashMatrix featureMatrix = new HashMatrix(true, itemSize, numberOfFeatures, new Long2FloatRBTreeMap()); + DataModule featureModel = space.getModule("article"); + String articleField = configuration.getString("data.model.fields.article"); + String featureField = configuration.getString("data.model.fields.feature"); + String degreeField = configuration.getString("data.model.fields.degree"); + int articleDimension = featureModel.getQualityInner(articleField); + int featureDimension = featureModel.getQualityInner(featureField); + int degreeDimension = featureModel.getQuantityInner(degreeField); + for (DataInstance instance : featureModel) { + int itemIndex = instance.getQualityFeature(articleDimension); + int featureIndex = instance.getQualityFeature(featureDimension); + float featureValue = instance.getQuantityFeature(degreeDimension); + featureMatrix.setValue(itemIndex, featureIndex, featureValue); + } + + // 物品矩阵 + itemVectors = SparseMatrix.valueOf(itemSize, numberOfFeatures, featureMatrix); + DocumentIterator iterator = new DocumentIterator(); + Int2FloatSortedMap keyValues = new Int2FloatAVLTreeMap(); + InverseDocumentFrequency inverseDocumentFrequency = new NaturalInverseDocumentFrequency(keyValues, iterator); + /** k控制着词频饱和度,值越小饱和度变化越快,值越大饱和度变化越慢 */ + float k = 1.2F; + /** b控制着词频归一化所起的作用,0.0会完全禁用归一化,1.0会完全启用归一化 */ + float b = 0.75F; + float avgdl = 0F; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + MathVector itemVector = itemVectors.getRowVector(itemIndex); + avgdl += itemVector.getElementSize(); + } + avgdl /= itemSize; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + MathVector itemVector = itemVectors.getRowVector(itemIndex); + float l = itemVector.getElementSize() / avgdl; + for (VectorScalar scalar : itemVector) { + float tf = scalar.getValue(); + float idf = inverseDocumentFrequency.getValue(scalar.getIndex()); + // use BM25 +// scalar.setValue((idf * (k + 1F) * tf) / (k * (1F - b + b * l) + tf)); + // use TF-IDF + scalar.setValue((idf * tf)); + } + // 归一化 + itemVector.scaleValues(1F / itemVector.getNorm(2F, true)); + } + + // 用户矩阵 + userVectors = new 
ArrayVector[userSize]; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + MathVector rowVector = scoreMatrix.getRowVector(userIndex); + HashVector userVector = new HashVector(0L, numberOfFeatures, new Long2FloatAVLTreeMap()); + for (VectorScalar scalar : rowVector) { + int itemIndex = scalar.getIndex(); + MathVector itemVector = itemVectors.getRowVector(itemIndex); + for (int position = 0; position < itemVector.getElementSize(); position++) { + float value = userVector.getValue(itemVector.getIndex(position)); + userVector.setValue(itemVector.getIndex(position), Float.isNaN(value) ? itemVector.getValue(position) : value + itemVector.getValue(position)); + } + } + userVector.scaleValues(1F / rowVector.getElementSize()); + Neighborhood knn = new Neighborhood(50, comparator); + for (int position = 0; position < userVector.getElementSize(); position++) { + knn.updateNeighbor(new Integer2FloatKeyValue(userVector.getIndex(position), userVector.getValue(position))); + } + userVector = new HashVector(0L, numberOfFeatures, new Long2FloatAVLTreeMap()); + Collection neighbors = knn.getNeighbors(); + for (Integer2FloatKeyValue neighbor : neighbors) { + userVector.setValue(neighbor.getKey(), neighbor.getValue()); + } + userVectors[userIndex] = new ArrayVector(userVector); + } + } + + @Override + protected void doPractice() { + } + + @Override + protected float predict(int userIndex, int itemIndex) { + MathVector userVector = userVectors[userIndex]; + MathVector itemVector = itemVectors.getRowVector(itemIndex); + return DefaultScalar.getInstance().dotProduct(userVector, itemVector).getValue(); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/content/rating/EFMRatingModel.java b/src/main/java/com/jstarcraft/rns/model/content/rating/EFMRatingModel.java new file mode 100644 index 0000000..95953fb --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/rating/EFMRatingModel.java @@ -0,0 +1,28 @@ +package com.jstarcraft.rns.model.content.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.rns.model.content.EFMModel; + +/** + * + * User KNN推荐器 + * + *
+ * <pre>
+ * Explicit factor models for explainable recommendation based on phrase-level sentiment analysis<br>
+ * Reference: LibRec team<br>
+ * </pre>
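+ *
+ * Rating sketch: this class simply delegates to EFMModel.predict, i.e. the
+ * sum of the explicit and implicit factor dot products,
+ * U1(u) . U2(i) + H1(u) . H2(i), without any further clamping.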
+ * + * @author Birdy + * + */ +public class EFMRatingModel extends EFMModel { + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = predict(userIndex, itemIndex); + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/content/rating/HFTModel.java b/src/main/java/com/jstarcraft/rns/model/content/rating/HFTModel.java new file mode 100644 index 0000000..6059ca1 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/rating/HFTModel.java @@ -0,0 +1,380 @@ +package com.jstarcraft.rns.model.content.rating; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.attribute.MemoryQualityAttribute; +import com.jstarcraft.ai.math.MathUtility; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.model.neuralnetwork.activation.ActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.SoftMaxActivationFunction; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.utility.SampleUtility; + +import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap; + +/** + * + * HFT推荐器 + * + *
+ * <pre>
+ * Hidden factors and hidden topics: understanding rating dimensions with review text<br>
+ * Reference: LibRec team<br>
+ * </pre>
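+ *
+ * Training sketch (as in prepare/doPractice below): the latent factors double
+ * as topic parameters through a softmax link, and SGD passes over the scores
+ * alternate with re-sampling of the per-word topic assignments:
+ *
+ * <pre>
+ * theta(u) = softmax(userFactors(u))   // user-topic distribution
+ * phi(k)   = softmax(wordFactors(k))   // topic-word distribution
+ * repeat: several SGD passes on scores plus topic log-likelihood, then re-sample word topics from theta(u) * phi
+ * </pre>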
+ * + * @author Birdy + * + */ +public class HFTModel extends MatrixFactorizationModel { + + private static class Content { + + private int[] wordIndexes; + + private int[] topicIndexes; + + private Content(int[] wordIndexes) { + this.wordIndexes = wordIndexes; + } + + int[] getWordIndexes() { + return wordIndexes; + } + + int[] getTopicIndexes() { + return topicIndexes; + } + + void setTopicIndexes(int[] topicIndexes) { + this.topicIndexes = topicIndexes; + } + + } + + // TODO 考虑重构 + private Int2ObjectRBTreeMap contentMatrix; + + private DenseMatrix wordFactors; + + protected String commentField; + protected int commentDimension; + /** 单词数量(TODO 考虑改名为numWords) */ + private int numberOfWords; + /** + * user biases + */ + private DenseVector userBiases; + + /** + * user biases + */ + private DenseVector itemBiases; + /** + * user latent factors + */ + // TODO 取消,父类已实现. + private DenseMatrix userFactors; + + /** + * item latent factors + */ + // TODO 取消,父类已实现. + private DenseMatrix itemFactors; + /** + * init mean + */ + // TODO 取消,父类已实现. + private float initMean; + + /** + * init standard deviation + */ + // TODO 取消,父类已实现. + private float initStd; + /** + * bias regularization + */ + private float biasRegularization; + /** + * user regularization + */ + // TODO 取消,父类已实现. + private float userRegularization; + + /** + * item regularization + */ + // TODO 取消,父类已实现. + private float itemRegularization; + + private DenseVector probability; + + private DenseMatrix userProbabilities; + private DenseMatrix wordProbabilities; + + protected ActivationFunction function; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + commentField = configuration.getString("data.model.fields.comment"); + commentDimension = model.getQualityInner(commentField); + MemoryQualityAttribute attribute = (MemoryQualityAttribute) space.getQualityAttribute(commentField); + Object[] wordValues = attribute.getDatas(); + + biasRegularization = configuration.getFloat("recommender.bias.regularization", 0.01F); + userRegularization = configuration.getFloat("recommender.user.regularization", 0.01F); + itemRegularization = configuration.getFloat("recommender.item.regularization", 0.01F); + + userFactors = DenseMatrix.valueOf(userSize, factorSize); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + + // TODO 此处需要重构initMean与initStd + initMean = 0.0F; + initStd = 0.1F; + userBiases = DenseVector.valueOf(userSize); + userBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + numberOfWords = 0; + // build review matrix and counting the number of words + contentMatrix = new Int2ObjectRBTreeMap<>(); + Map wordDictionaries = new HashMap<>(); + for (DataInstance sample : model) { + int userIndex = sample.getQualityFeature(userDimension); + int itemIndex = sample.getQualityFeature(itemDimension); + int contentIndex = sample.getQualityFeature(commentDimension); + String data = (String) wordValues[contentIndex]; + String[] words = data.isEmpty() ? 
new String[0] : data.split(":"); + for (String word : words) { + if (!wordDictionaries.containsKey(word) && StringUtils.isNotEmpty(word)) { + wordDictionaries.put(word, numberOfWords); + numberOfWords++; + } + } + // TODO 此处旧代码使用indexes[index] = + // Integer.valueOf(words[index])似乎有Bug,应该使用indexes[index] = + // wordDictionaries.get(word); + int[] wordIndexes = new int[words.length]; + for (int index = 0; index < words.length; index++) { + wordIndexes[index] = Integer.valueOf(words[index]); + } + Content content = new Content(wordIndexes); + contentMatrix.put(userIndex * itemSize + itemIndex, content); + } + + // TODO 此处保证所有特征都会被识别 + for (Object value : wordValues) { + String content = (String) value; + String[] words = content.split(":"); + for (String word : words) { + if (!wordDictionaries.containsKey(word) && StringUtils.isNotEmpty(word)) { + wordDictionaries.put(word, numberOfWords); + numberOfWords++; + } + } + } + + logger.info("number of users : " + userSize); + logger.info("number of items : " + itemSize); + logger.info("number of words : " + numberOfWords); + + wordFactors = DenseMatrix.valueOf(factorSize, numberOfWords); + wordFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(0.1F)); + }); + + userProbabilities = DenseMatrix.valueOf(userSize, factorSize); + wordProbabilities = DenseMatrix.valueOf(factorSize, numberOfWords); + probability = DenseVector.valueOf(factorSize); + probability.setValues(1F); + + function = new SoftMaxActivationFunction(); + + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // user + int itemIndex = term.getColumn(); // item + Content content = contentMatrix.get(userIndex * itemSize + itemIndex); + int[] wordIndexes = content.getWordIndexes(); + int[] topicIndexes = new int[wordIndexes.length]; + for (int wordIndex = 0; wordIndex < wordIndexes.length; wordIndex++) { + topicIndexes[wordIndex] = RandomUtility.randomInteger(factorSize); + } + content.setTopicIndexes(topicIndexes); + } + calculateThetas(); + calculatePhis(); + } + + private void sample() { + calculateThetas(); + calculatePhis(); + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // user + int itemIndex = term.getColumn(); // item + Content content = contentMatrix.get(userIndex * itemSize + itemIndex); + int[] wordIndexes = content.getWordIndexes(); + int[] topicIndexes = content.getTopicIndexes(); + sampleTopicsToWords(userIndex, wordIndexes, topicIndexes); + // LOG.info("user:" + u + ", item:" + j + ", topics:" + s); + } + } + + /** + * Update function for thetas and phiks, check if softmax comes in to NaN and + * update the parameters. 
+ * + * @param oldValues old values of the parameter + * @param newValues new values to update the parameter + * @return the old values if new values contain NaN + * @throws Exception if error occurs + */ + private float[] updateArray(float[] oldValues, float[] newValues) { + for (float value : newValues) { + if (Float.isNaN(value)) { + return oldValues; + } + } + return newValues; + } + + private void calculateThetas() { + for (int userIndex = 0; userIndex < userSize; userIndex++) { + DenseVector factorVector = userFactors.getRowVector(userIndex); + function.forward(factorVector, userProbabilities.getRowVector(userIndex)); + } + } + + private void calculatePhis() { + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + DenseVector factorVector = wordFactors.getRowVector(factorIndex); + function.forward(factorVector, wordProbabilities.getRowVector(factorIndex)); + } + } + + // TODO 考虑整理到Content. + private int[] sampleTopicsToWords(int userIndex, int[] wordsIndexes, int[] topicIndexes) { + for (int wordIndex = 0; wordIndex < wordsIndexes.length; wordIndex++) { + int topicIndex = wordsIndexes[wordIndex]; + DefaultScalar sum = DefaultScalar.getInstance(); + sum.setValue(0F); + probability.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = userProbabilities.getValue(userIndex, index) * wordProbabilities.getValue(index, topicIndex); + sum.shiftValue(value); + scalar.setValue(sum.getValue()); + }); + topicIndexes[wordIndex] = SampleUtility.binarySearch(probability, 0, probability.getElementSize() - 1, RandomUtility.randomFloat(sum.getValue())); + } + return topicIndexes; + } + + /** + * The training approach is SGD instead of L-BFGS, so it can be slow if the + * dataset is big. + */ + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + // SGD training + // TODO 此处应该修改为配置 + for (int iterationSDG = 0; iterationSDG < 5; iterationSDG++) { + totalError = 0F; + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // user + int itemIndex = term.getColumn(); // item + float score = term.getValue(); + + float predict = predict(userIndex, itemIndex); + float error = score - predict; + totalError += error * error; + + // update factors + float userBias = userBiases.getValue(userIndex); + float userSgd = error - biasRegularization * userBias; + userBiases.shiftValue(userIndex, learnRatio * userSgd); + // loss += regB * bu * bu; + float itemBias = itemBiases.getValue(itemIndex); + float itemSgd = error - biasRegularization * itemBias; + itemBiases.shiftValue(itemIndex, learnRatio * itemSgd); + // loss += regB * bj * bj; + + // TODO 此处应该重构 + Content content = contentMatrix.get(userIndex * itemSize + itemIndex); + int[] wordIndexes = content.getWordIndexes(); + if (wordIndexes.length == 0) { + continue; + } + int[] topicIndexes = content.getTopicIndexes(); + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float itemFactor = itemFactors.getValue(itemIndex, factorIndex); + float userSGD = error * itemFactor - userRegularization * userFactor; + float itemSGD = error * userFactor - itemRegularization * itemFactor; + userFactors.shiftValue(userIndex, factorIndex, learnRatio * userSGD); + itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * itemSGD); + for (int wordIndex = 0; wordIndex < wordIndexes.length; wordIndex++) { + int topicIndex = topicIndexes[wordIndex]; + if 
(factorIndex == topicIndex) { + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (1 - userProbabilities.getValue(userIndex, topicIndex))); + } else { + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (-userProbabilities.getValue(userIndex, topicIndex))); + } + totalError -= MathUtility.logarithm(userProbabilities.getValue(userIndex, topicIndex) * wordProbabilities.getValue(topicIndex, wordIndexes[wordIndex]), 2); + } + } + + for (int wordIndex = 0; wordIndex < wordIndexes.length; wordIndex++) { + int topicIndex = topicIndexes[wordIndex]; + for (int dictionaryIndex = 0; dictionaryIndex < numberOfWords; dictionaryIndex++) { + if (dictionaryIndex == wordIndexes[wordIndex]) { + wordFactors.shiftValue(topicIndex, wordIndexes[wordIndex], learnRatio * (-1 + wordProbabilities.getValue(topicIndex, wordIndexes[wordIndex]))); + } else { + wordFactors.shiftValue(topicIndex, wordIndexes[wordIndex], learnRatio * (wordProbabilities.getValue(topicIndex, wordIndexes[wordIndex]))); + } + } + } + } + totalError *= 0.5F; + } // end of SGDtraining + logger.info(" iter:" + epocheIndex + ", loss:" + totalError); + logger.info(" iter:" + epocheIndex + ", sampling"); + sample(); + logger.info(" iter:" + epocheIndex + ", sample finished"); + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector userVector = userFactors.getRowVector(userIndex); + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + float value = scalar.dotProduct(userVector, itemVector).getValue(); + value += meanScore + userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex); + if (value > maximumScore) { + value = maximumScore; + } else if (value < minimumScore) { + value = minimumScore; + } + return value; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/content/rating/TopicMFATModel.java b/src/main/java/com/jstarcraft/rns/model/content/rating/TopicMFATModel.java new file mode 100644 index 0000000..3cde33d --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/rating/TopicMFATModel.java @@ -0,0 +1,281 @@ +package com.jstarcraft.rns.model.content.rating; + +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Table; +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.attribute.MemoryQualityAttribute; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.ai.model.neuralnetwork.activation.ActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.SoftMaxActivationFunction; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * TopicMF AT推荐器 + * + *
+ * <pre>
+ * TopicMF: Simultaneously Exploiting Ratings and Reviews for Recommendation<br>
+ * Reference: LibRec team<br>
+ * </pre>
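+ *
+ * Model sketch (AT variant; see calculateTheta below): each document-topic
+ * vector is tied to the absolute user and item factors through the learned
+ * scales K1 and K2, and the word matrix W is factorized as theta * phi with
+ * NMF-style multiplicative updates:
+ *
+ * <pre>
+ * theta(d) = softmax(K1 * |u| + K2 * |v|)
+ * phi <- phi .* (theta' * W) ./ (theta' * theta * phi)
+ * </pre>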
+ * + * @author Birdy + * + */ +public class TopicMFATModel extends MatrixFactorizationModel { + + protected String commentField; + protected int commentDimension; + protected SparseMatrix W; + protected DenseMatrix documentFactors; + protected DenseMatrix wordFactors; + protected float K1, K2; + protected DenseVector userBiases; + protected DenseVector itemBiases; + // TODO 准备取消,父类已实现. + protected DenseMatrix userFactors; + protected DenseMatrix itemFactors; + // TODO topic似乎就是factor? + protected int numberOfTopics; + protected int numberOfWords; + protected int numberOfDocuments; + + protected float lambda, lambdaU, lambdaV, lambdaB; + + protected Table userItemToDocument; + // TODO 准备取消,父类已实现. + protected float initMean; + protected float initStd; + + protected DenseVector topicVector; + protected ActivationFunction function; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + commentField = configuration.getString("data.model.fields.comment"); + commentDimension = model.getQualityInner(commentField); + MemoryQualityAttribute attribute = (MemoryQualityAttribute) space.getQualityAttribute(commentField); + Object[] documentValues = attribute.getDatas(); + + // init hyper-parameters + lambda = configuration.getFloat("recommender.regularization.lambda", 0.001F); + lambdaU = configuration.getFloat("recommender.regularization.lambdaU", 0.001F); + lambdaV = configuration.getFloat("recommender.regularization.lambdaV", 0.001F); + lambdaB = configuration.getFloat("recommender.regularization.lambdaB", 0.001F); + numberOfTopics = configuration.getInteger("recommender.topic.number", 10); + learnRatio = configuration.getFloat("recommender.iterator.learnrate", 0.01F); + epocheSize = configuration.getInteger("recommender.iterator.maximum", 10); + + numberOfDocuments = scoreMatrix.getElementSize(); + + // count the number of words, build the word dictionary and + // userItemToDoc dictionary + Map wordDictionaries = new HashMap<>(); + Table documentTable = HashBasedTable.create(); + // TODO rowCount改为documentIndex? + int rowCount = 0; + userItemToDocument = HashBasedTable.create(); + for (DataInstance sample : model) { + int userIndex = sample.getQualityFeature(userDimension); + int itemIndex = sample.getQualityFeature(itemDimension); + int documentIndex = sample.getQualityFeature(commentDimension); + userItemToDocument.put(userIndex, itemIndex, rowCount); + // convert wordIds to wordIndices + String data = (String) documentValues[documentIndex]; + String[] words = data.isEmpty() ? 
new String[0] : data.split(":"); + for (String word : words) { + Integer wordIndex = wordDictionaries.get(word); + if (wordIndex == null) { + wordIndex = numberOfWords++; + wordDictionaries.put(word, wordIndex); + } + Float oldValue = documentTable.get(rowCount, wordIndex); + if (oldValue == null) { + oldValue = 0F; + } + float newValue = oldValue + 1F / words.length; + documentTable.put(rowCount, wordIndex, newValue); + } + rowCount++; + } + // build W + W = SparseMatrix.valueOf(numberOfDocuments, numberOfWords, documentTable); + + userBiases = DenseVector.valueOf(userSize); + userBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + userFactors = DenseMatrix.valueOf(userSize, numberOfTopics); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemFactors = DenseMatrix.valueOf(itemSize, numberOfTopics); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + K1 = initStd; + K2 = initStd; + + topicVector = DenseVector.valueOf(numberOfTopics); + function = new SoftMaxActivationFunction(); + + // init theta and phi + // TODO theta实际是documentFactors + documentFactors = DenseMatrix.valueOf(numberOfDocuments, numberOfTopics); + calculateTheta(); + // TODO phi实际是wordFactors + wordFactors = DenseMatrix.valueOf(numberOfTopics, numberOfWords); + wordFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(0.01F)); + }); + + logger.info("number of users : " + userSize); + logger.info("number of Items : " + itemSize); + logger.info("number of words : " + wordDictionaries.size()); + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseMatrix transposeThis = DenseMatrix.valueOf(numberOfTopics, numberOfTopics); + DenseMatrix thetaW = DenseMatrix.valueOf(numberOfTopics, numberOfWords); + DenseMatrix thetaPhi = DenseMatrix.valueOf(numberOfTopics, numberOfWords); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + float wordLoss = 0F; + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // userIdx + int itemIndex = term.getColumn(); // itemIdx + int documentIndex = userItemToDocument.get(userIndex, itemIndex); + float y_true = term.getValue(); + float y_pred = predict(userIndex, itemIndex); + + float error = y_true - y_pred; + totalError += error * error; + + // update user item biases + float userBiasValue = userBiases.getValue(userIndex); + userBiases.shiftValue(userIndex, learnRatio * (error - lambdaB * userBiasValue)); + totalError += lambdaB * userBiasValue * userBiasValue; + + float itemBiasValue = itemBiases.getValue(itemIndex); + itemBiases.shiftValue(itemIndex, learnRatio * (error - lambdaB * itemBiasValue)); + totalError += lambdaB * itemBiasValue * itemBiasValue; + + // update user item factors + for (int factorIndex = 0; factorIndex < numberOfTopics; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float itemFactor = itemFactors.getValue(itemIndex, factorIndex); + + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (error * itemFactor - lambdaU * userFactor)); + itemFactors.shiftValue(itemIndex, 
factorIndex, learnRatio * (error * userFactor - lambdaV * itemFactor)); + totalError += lambdaU * userFactor * userFactor + lambdaV * itemFactor * itemFactor; + + SparseVector documentVector = W.getRowVector(documentIndex); + for (VectorScalar documentTerm : documentVector) { + int wordIndex = documentTerm.getIndex(); + float w_pred = scalar.dotProduct(documentFactors.getRowVector(documentIndex), wordFactors.getColumnVector(wordIndex)).getValue(); + float w_true = documentTerm.getValue(); + float w_error = w_true - w_pred; + wordLoss += w_error; + + float derivative = 0F; + for (int topicIndex = 0; topicIndex < numberOfTopics; topicIndex++) { + if (factorIndex == topicIndex) { + derivative += w_error * wordFactors.getValue(topicIndex, wordIndex) * documentFactors.getValue(documentIndex, topicIndex) * (1 - documentFactors.getValue(documentIndex, topicIndex)); + } else { + derivative += w_error * wordFactors.getValue(topicIndex, wordIndex) * documentFactors.getValue(documentIndex, topicIndex) * (-documentFactors.getValue(documentIndex, factorIndex)); + } + // update K1 K2 + K1 += learnRatio * lambda * w_error * wordFactors.getValue(topicIndex, wordIndex) * documentFactors.getValue(documentIndex, topicIndex) * (1 - documentFactors.getValue(documentIndex, topicIndex)) * Math.abs(userFactors.getValue(userIndex, topicIndex)); + K2 += learnRatio * lambda * w_error * wordFactors.getValue(topicIndex, wordIndex) * documentFactors.getValue(documentIndex, topicIndex) * (1 - documentFactors.getValue(documentIndex, topicIndex)) * Math.abs(itemFactors.getValue(itemIndex, topicIndex)); + } + userFactors.shiftValue(userIndex, factorIndex, learnRatio * K1 * derivative); + itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * K2 * derivative); + } + } + } + // calculate theta + logger.info(" iter:" + epocheIndex + ", finish factors update"); + + // calculate wordLoss and loss + wordLoss = wordLoss / numberOfTopics; + totalError += wordLoss; + totalError *= 0.5F; + logger.info(" iter:" + epocheIndex + ", loss:" + totalError + ", wordLoss:" + wordLoss / 2F); + + calculateTheta(); + logger.info(" iter:" + epocheIndex + ", finish theta update"); + + // update phi by NMF + // TODO 此处操作可以整合 + thetaW.dotProduct(documentFactors, true, W, false, MathCalculator.SERIAL); + transposeThis.dotProduct(documentFactors, true, documentFactors, false, MathCalculator.SERIAL); + thetaPhi.dotProduct(transposeThis, false, wordFactors, false, MathCalculator.SERIAL); + for (int topicIndex = 0; topicIndex < numberOfTopics; topicIndex++) { + for (int wordIndex = 0; wordIndex < numberOfWords; wordIndex++) { + float numerator = wordFactors.getValue(topicIndex, wordIndex) * thetaW.getValue(topicIndex, wordIndex); + float denominator = thetaPhi.getValue(topicIndex, wordIndex); + wordFactors.setValue(topicIndex, wordIndex, numerator / denominator); + } + } + logger.info(" iter:" + epocheIndex + ", finish phi update"); + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + float value = meanScore + userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex); + value += scalar.dotProduct(userFactors.getRowVector(userIndex), itemFactors.getRowVector(itemIndex)).getValue(); + return value; + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + + /** + * 
Calculate theta vectors via userFactors and itemFactors. thetaVector = + * softmax( K1|u| + K2|v| ) + */ + private void calculateTheta() { + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + int documentIdx = userItemToDocument.get(userIndex, itemIndex); + DenseVector documentVector = documentFactors.getRowVector(documentIdx); + topicVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + value = Math.abs(userFactors.getValue(userIndex, index)) * K1 + Math.abs(itemFactors.getValue(itemIndex, index)) * K2; + scalar.setValue(value); + }); + function.forward(topicVector, documentVector); + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/content/rating/TopicMFMTModel.java b/src/main/java/com/jstarcraft/rns/model/content/rating/TopicMFMTModel.java new file mode 100644 index 0000000..99d5195 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/content/rating/TopicMFMTModel.java @@ -0,0 +1,277 @@ +package com.jstarcraft.rns.model.content.rating; + +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Table; +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.attribute.MemoryQualityAttribute; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.ai.model.neuralnetwork.activation.ActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.SoftMaxActivationFunction; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; + +/** + * + * TopicMF MT recommender + * + * <pre>
+ * TopicMF: Simultaneously Exploiting Ratings and Reviews for Recommendation
+ * Based on the LibRec team's implementation
+ * </pre>
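+ * + * Prediction as implemented below: r(u, i) = meanScore + b(u) + b(i) + dot(p(u), q(i)); document topic proportions are theta(d) = softmax(K * |p(u)| * |q(i)|) (element-wise product over topics) and the word weights of document d are approximated by dot(theta(d), phi).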
+ * + * @author Birdy + * + */ +public class TopicMFMTModel extends MatrixFactorizationModel { + + protected String commentField; + protected int commentDimension; + protected SparseMatrix W; + protected DenseMatrix documentFactors; + protected DenseMatrix wordFactors; + protected float K; + protected DenseVector userBiases; + protected DenseVector itemBiases; + // TODO to be removed; already implemented in the parent class. + protected DenseMatrix userFactors; + protected DenseMatrix itemFactors; + // TODO topic seems to be the same as factor? + protected int numberOfTopics; + protected int numberOfWords; + protected int numberOfDocuments; + + protected float lambda, lambdaU, lambdaV, lambdaB; + + protected Table<Integer, Integer, Integer> userItemToDocument; + // TODO to be removed; already implemented in the parent class. + protected float initMean; + protected float initStd; + + protected DenseVector topicVector; + protected ActivationFunction function; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + + commentField = configuration.getString("data.model.fields.comment"); + commentDimension = model.getQualityInner(commentField); + MemoryQualityAttribute attribute = (MemoryQualityAttribute) space.getQualityAttribute(commentField); + Object[] documentValues = attribute.getDatas(); + + // init hyper-parameters + lambda = configuration.getFloat("recommender.regularization.lambda", 0.001F); + lambdaU = configuration.getFloat("recommender.regularization.lambdaU", 0.001F); + lambdaV = configuration.getFloat("recommender.regularization.lambdaV", 0.001F); + lambdaB = configuration.getFloat("recommender.regularization.lambdaB", 0.001F); + numberOfTopics = configuration.getInteger("recommender.topic.number", 10); + learnRatio = configuration.getFloat("recommender.iterator.learnrate", 0.01F); + epocheSize = configuration.getInteger("recommender.iterator.maximum", 10); + + numberOfDocuments = scoreMatrix.getElementSize(); + + // count the number of words, build the word dictionary and + // userItemToDoc dictionary + Map<String, Integer> wordDictionaries = new HashMap<>(); + Table<Integer, Integer, Float> documentTable = HashBasedTable.create(); + int rowCount = 0; + userItemToDocument = HashBasedTable.create(); + for (DataInstance sample : model) { + int userIndex = sample.getQualityFeature(userDimension); + int itemIndex = sample.getQualityFeature(itemDimension); + int documentIndex = sample.getQualityFeature(commentDimension); + userItemToDocument.put(userIndex, itemIndex, rowCount); + // convert wordIds to wordIndices + String data = (String) documentValues[documentIndex]; + String[] words = data.isEmpty() ? 
new String[0] : data.split(":"); + for (String word : words) { + Integer wordIndex = wordDictionaries.get(word); + if (wordIndex == null) { + wordIndex = numberOfWords++; + wordDictionaries.put(word, wordIndex); + } + Float oldValue = documentTable.get(rowCount, wordIndex); + if (oldValue == null) { + oldValue = 0F; + } + float newValue = oldValue + 1F / words.length; + documentTable.put(rowCount, wordIndex, newValue); + } + rowCount++; + } + // build W + W = SparseMatrix.valueOf(numberOfDocuments, numberOfWords, documentTable); + + userBiases = DenseVector.valueOf(userSize); + userBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + userFactors = DenseMatrix.valueOf(userSize, numberOfTopics); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemFactors = DenseMatrix.valueOf(itemSize, numberOfTopics); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + K = initStd; + + topicVector = DenseVector.valueOf(numberOfTopics); + function = new SoftMaxActivationFunction(); + + // init theta and phi + // TODO theta is actually documentFactors + documentFactors = DenseMatrix.valueOf(numberOfDocuments, numberOfTopics); + calculateTheta(); + // TODO phi is actually wordFactors + wordFactors = DenseMatrix.valueOf(numberOfTopics, numberOfWords); + wordFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(0.01F)); + }); + + logger.info("number of users : " + userSize); + logger.info("number of items : " + itemSize); + logger.info("number of words : " + wordDictionaries.size()); + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseMatrix transposeThis = DenseMatrix.valueOf(numberOfTopics, numberOfTopics); + DenseMatrix thetaW = DenseMatrix.valueOf(numberOfTopics, numberOfWords); + DenseMatrix thetaPhi = DenseMatrix.valueOf(numberOfTopics, numberOfWords); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + float wordLoss = 0F; + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); // userIdx + int itemIndex = term.getColumn(); // itemIdx + int documentIndex = userItemToDocument.get(userIndex, itemIndex); + float y_true = term.getValue(); + float y_pred = predict(userIndex, itemIndex); + + float error = y_true - y_pred; + totalError += error * error; + + // update user item biases + float userBiasValue = userBiases.getValue(userIndex); + userBiases.shiftValue(userIndex, learnRatio * (error - lambdaB * userBiasValue)); + totalError += lambdaB * userBiasValue * userBiasValue; + + float itemBiasValue = itemBiases.getValue(itemIndex); + itemBiases.shiftValue(itemIndex, learnRatio * (error - lambdaB * itemBiasValue)); + totalError += lambdaB * itemBiasValue * itemBiasValue; + + // update user item factors + for (int factorIndex = 0; factorIndex < numberOfTopics; factorIndex++) { + float userFactorValue = userFactors.getValue(userIndex, factorIndex); + float itemFactorValue = itemFactors.getValue(itemIndex, factorIndex); + + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (error * itemFactorValue - lambdaU * userFactorValue)); + itemFactors.shiftValue(itemIndex, 
factorIndex, learnRatio * (error * userFactorValue - lambdaV * itemFactorValue)); + totalError += lambdaU * userFactorValue * userFactorValue + lambdaV * itemFactorValue * itemFactorValue; + + SparseVector documentVector = W.getRowVector(documentIndex); + for (VectorScalar documentTerm : documentVector) { + int wordIndex = documentTerm.getIndex(); + float w_pred = scalar.dotProduct(documentFactors.getRowVector(documentIndex), wordFactors.getColumnVector(wordIndex)).getValue(); + float w_true = documentTerm.getValue(); + float w_error = w_true - w_pred; + wordLoss += w_error; + + float derivative = 0F; + for (int topicIndex = 0; topicIndex < numberOfTopics; topicIndex++) { + if (factorIndex == topicIndex) { + derivative += w_error * wordFactors.getValue(topicIndex, wordIndex) * documentFactors.getValue(documentIndex, topicIndex) * (1 - documentFactors.getValue(documentIndex, topicIndex)); + } else { + derivative += w_error * wordFactors.getValue(topicIndex, wordIndex) * documentFactors.getValue(documentIndex, topicIndex) * (-documentFactors.getValue(documentIndex, factorIndex)); + } + // update K + K += learnRatio * lambda * w_error * wordFactors.getValue(topicIndex, wordIndex) * documentFactors.getValue(documentIndex, topicIndex) * (1 - documentFactors.getValue(documentIndex, topicIndex)) * Math.abs(userFactors.getValue(userIndex, topicIndex)); + } + userFactors.shiftValue(userIndex, factorIndex, learnRatio * K * derivative * itemFactors.getValue(itemIndex, factorIndex)); + itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * K * derivative * userFactors.getValue(userIndex, factorIndex)); + } + } + } + // calculate theta + logger.info(" iter:" + epocheIndex + ", finish factors update"); + + // calculate wordLoss and loss + wordLoss = wordLoss / numberOfTopics; + totalError += wordLoss; + totalError *= 0.5F; + logger.info(" iter:" + epocheIndex + ", loss:" + totalError + ", wordLoss:" + wordLoss / 2F); + + calculateTheta(); + logger.info(" iter:" + epocheIndex + ", finish theta update"); + + // update phi by NMF + // TODO these operations could be consolidated + thetaW.dotProduct(documentFactors, true, W, false, MathCalculator.SERIAL); + transposeThis.dotProduct(documentFactors, true, documentFactors, false, MathCalculator.SERIAL); + thetaPhi.dotProduct(transposeThis, false, wordFactors, false, MathCalculator.SERIAL); + for (int topicIndex = 0; topicIndex < numberOfTopics; topicIndex++) { + for (int wordIndex = 0; wordIndex < numberOfWords; wordIndex++) { + float numerator = wordFactors.getValue(topicIndex, wordIndex) * thetaW.getValue(topicIndex, wordIndex); + float denominator = thetaPhi.getValue(topicIndex, wordIndex); + wordFactors.setValue(topicIndex, wordIndex, numerator / denominator); + } + } + logger.info(" iter:" + epocheIndex + ", finish phi update"); + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + float value = meanScore + userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex); + value += scalar.dotProduct(userFactors.getRowVector(userIndex), itemFactors.getRowVector(itemIndex)).getValue(); + return value; + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + + /** + * Calculate theta vectors via userFactors and itemFactors. 
thetaVector = + * softmax( K|u||v| ) + */ + private void calculateTheta() { + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + int documentIdx = userItemToDocument.get(userIndex, itemIndex); + DenseVector documentVector = documentFactors.getRowVector(documentIdx); + topicVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + value = K * Math.abs(userFactors.getValue(userIndex, index)) * Math.abs(itemFactors.getValue(itemIndex, index)); + scalar.setValue(value); + }); + function.forward(topicVector, documentVector); + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/context.txt b/src/main/java/com/jstarcraft/rns/model/context/context.txt new file mode 100644 index 0000000..7acfbda --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/context.txt @@ -0,0 +1,6 @@ +http://www.ntu.edu.sg/home/gaocong/ + +http://www.vldb.org/pvldb/vol10/p1010-liu.pdf + +Twelve POI algorithms. +http://spatialkeyword.sce.ntu.edu.sg/eval-vldb17/ \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/context/ranking/RankGeoFMModel.java b/src/main/java/com/jstarcraft/rns/model/context/ranking/RankGeoFMModel.java new file mode 100644 index 0000000..1937723 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/ranking/RankGeoFMModel.java @@ -0,0 +1,327 @@ +package com.jstarcraft.rns.model.context.ranking; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.ai.math.structure.vector.ArrayVector; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.Float2FloatKeyValue; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.MatrixFactorizationModel; +import com.jstarcraft.rns.model.exception.ModelException; +import com.jstarcraft.rns.utility.LogisticUtility; + +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; + +/** + * + * Rank GeoFM recommender + * + * <pre>
+ * Rank-GeoFM: A ranking based geographical factorization method for point of interest recommendation
+ * Based on the LibRec team's implementation
+ * </pre>
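+ * + * Score as implemented below: y(u, i) = dot(explicit(u), item(i)) + dot(implicit(u), geo(i)), where geo(i) is the distance-weighted sum of the factors of item i's k nearest neighbours (geographical influence); training minimizes a margin-based ranking loss whose weight E(r) = 1 + 1/2 + ... + 1/r grows with the estimated rank r.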
+ * + * @author Birdy + * + */ +public class RankGeoFMModel extends MatrixFactorizationModel { + + protected DenseMatrix explicitUserFactors, implicitUserFactors, itemFactors; + + protected ArrayVector[] neighborWeights; + + protected float margin, radius, balance; + + protected DenseVector E; + + protected DenseMatrix geoInfluences; + + protected int knn; + + protected Float2FloatKeyValue[] itemLocations; + + private String longitudeField, latitudeField; + + private int longitudeDimension, latitudeDimension; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + margin = configuration.getFloat("recommender.ranking.margin", 0.3F); + radius = configuration.getFloat("recommender.regularization.radius", 1F); + balance = configuration.getFloat("recommender.regularization.balance", 0.2F); + knn = configuration.getInteger("recommender.item.nearest.neighbour.number", 300); + + longitudeField = configuration.getString("data.model.fields.longitude"); + latitudeField = configuration.getString("data.model.fields.latitude"); + + DataModule locationModel = space.getModule("location"); + longitudeDimension = locationModel.getQuantityInner(longitudeField); + latitudeDimension = locationModel.getQuantityInner(latitudeField); + + geoInfluences = DenseMatrix.valueOf(itemSize, factorSize); + + explicitUserFactors = DenseMatrix.valueOf(userSize, factorSize); + explicitUserFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + implicitUserFactors = DenseMatrix.valueOf(userSize, factorSize); + implicitUserFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + itemLocations = new Float2FloatKeyValue[itemSize]; + + int itemDimension = locationModel.getQualityInner(itemField); + for (DataInstance instance : locationModel) { + int itemIndex = instance.getQualityFeature(itemDimension); + Float2FloatKeyValue itemLocation = new Float2FloatKeyValue(instance.getQuantityFeature(longitudeDimension), instance.getQuantityFeature(latitudeDimension)); + itemLocations[itemIndex] = itemLocation; + } + calculateNeighborWeightMatrix(knn); + + E = DenseVector.valueOf(itemSize + 1); + E.setValue(1, 1F); + for (int itemIndex = 2; itemIndex <= itemSize; itemIndex++) { + E.setValue(itemIndex, E.getValue(itemIndex - 1) + 1F / itemIndex); + } + + geoInfluences = DenseMatrix.valueOf(itemSize, factorSize); + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseMatrix explicitUserDeltas = DenseMatrix.valueOf(explicitUserFactors.getRowSize(), explicitUserFactors.getColumnSize()); + DenseMatrix implicitUserDeltas = DenseMatrix.valueOf(implicitUserFactors.getRowSize(), implicitUserFactors.getColumnSize()); + DenseMatrix itemDeltas = DenseMatrix.valueOf(itemFactors.getRowSize(), itemFactors.getColumnSize()); + + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + calculateGeoInfluenceMatrix(); + + totalError = 0F; + explicitUserDeltas.iterateElement(MathCalculator.PARALLEL, (element) -> { + element.setValue(explicitUserFactors.getValue(element.getRow(), element.getColumn())); + }); + implicitUserDeltas.iterateElement(MathCalculator.PARALLEL, (element) -> { + 
element.setValue(implicitUserFactors.getValue(element.getRow(), element.getColumn())); + }); + itemDeltas.iterateElement(MathCalculator.PARALLEL, (element) -> { + element.setValue(itemFactors.getValue(element.getRow(), element.getColumn())); + }); + + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + for (VectorScalar term : userVector) { + int positiveItemIndex = term.getIndex(); + + int sampleCount = 0; + float positiveScore = scalar.dotProduct(explicitUserDeltas.getRowVector(userIndex), itemDeltas.getRowVector(positiveItemIndex)).getValue() + scalar.dotProduct(implicitUserDeltas.getRowVector(userIndex), geoInfluences.getRowVector(positiveItemIndex)).getValue(); + float positiveValue = term.getValue(); + + int negativeItemIndex; + float negativeScore; + float negativeValue; + + while (true) { + negativeItemIndex = RandomUtility.randomInteger(itemSize); + negativeScore = scalar.dotProduct(explicitUserDeltas.getRowVector(userIndex), itemDeltas.getRowVector(negativeItemIndex)).getValue() + scalar.dotProduct(implicitUserDeltas.getRowVector(userIndex), geoInfluences.getRowVector(negativeItemIndex)).getValue(); + negativeValue = 0F; + for (VectorScalar rateTerm : userVector) { + if (rateTerm.getIndex() == negativeItemIndex) { + negativeValue = rateTerm.getValue(); + } + } + + sampleCount++; + if ((indicator(positiveValue, negativeValue) && indicator(negativeScore + margin, positiveScore)) || sampleCount > itemSize) { + break; + } + } + + if (indicator(positiveValue, negativeValue) && indicator(negativeScore + margin, positiveScore)) { + int sampleIndex = itemSize / sampleCount; + + float s = LogisticUtility.getValue(negativeScore + margin - positiveScore); + totalError += E.getValue(sampleIndex) * s; + + float uij = s * (1 - s); + float error = E.getValue(sampleIndex) * uij * learnRatio; + DenseVector positiveItemVector = itemFactors.getRowVector(positiveItemIndex); + DenseVector negativeItemVector = itemFactors.getRowVector(negativeItemIndex); + DenseVector explicitUserVector = explicitUserFactors.getRowVector(userIndex); + + DenseVector positiveGeoVector = geoInfluences.getRowVector(positiveItemIndex); + DenseVector negativeGeoVector = geoInfluences.getRowVector(negativeItemIndex); + DenseVector implicitUserVector = implicitUserFactors.getRowVector(userIndex); + + // TODO could be computed in parallel + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + explicitUserVector.setValue(factorIndex, explicitUserVector.getValue(factorIndex) - (negativeItemVector.getValue(factorIndex) - positiveItemVector.getValue(factorIndex)) * error); + implicitUserVector.setValue(factorIndex, implicitUserVector.getValue(factorIndex) - (negativeGeoVector.getValue(factorIndex) - positiveGeoVector.getValue(factorIndex)) * error); + } + // TODO could be computed in parallel + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float itemDelta = explicitUserVector.getValue(factorIndex) * error; + positiveItemVector.setValue(factorIndex, positiveItemVector.getValue(factorIndex) + itemDelta); + negativeItemVector.setValue(factorIndex, negativeItemVector.getValue(factorIndex) - itemDelta); + } + + float explicitUserDelta = explicitUserVector.getNorm(2, true); + if (explicitUserDelta > radius) { + explicitUserDelta = radius / explicitUserDelta; + } else { + explicitUserDelta = 1F; + } + float implicitUserDelta = implicitUserVector.getNorm(2F, true); + if (implicitUserDelta > balance * radius) { + implicitUserDelta = balance * radius / 
implicitUserDelta; + } else { + implicitUserDelta = 1F; + } + float positiveItemDelta = positiveItemVector.getNorm(2, true); + if (positiveItemDelta > radius) { + positiveItemDelta = radius / positiveItemDelta; + } else { + positiveItemDelta = 1F; + } + float negativeItemDelta = negativeItemVector.getNorm(2, true); + if (negativeItemDelta > radius) { + negativeItemDelta = radius / negativeItemDelta; + } else { + negativeItemDelta = 1F; + } + // TODO could be computed in parallel + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + if (explicitUserDelta != 1F) { + explicitUserVector.setValue(factorIndex, explicitUserVector.getValue(factorIndex) * explicitUserDelta); + } + if (implicitUserDelta != 1F) { + implicitUserVector.setValue(factorIndex, implicitUserVector.getValue(factorIndex) * implicitUserDelta); + } + if (positiveItemDelta != 1F) { + positiveItemVector.setValue(factorIndex, positiveItemVector.getValue(factorIndex) * positiveItemDelta); + } + if (negativeItemDelta != 1F) { + negativeItemVector.setValue(factorIndex, negativeItemVector.getValue(factorIndex) * negativeItemDelta); + } + } + } + } + } + + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + /** + * Build the k-nearest-neighbour weight vectors from the geographical distances between items. + * + * @param k_nearest number of nearest neighbours to keep + */ + private void calculateNeighborWeightMatrix(Integer k_nearest) { + HashMatrix dataTable = new HashMatrix(true, itemSize, itemSize, new Long2FloatRBTreeMap()); + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + List<KeyValue<Integer, Float>> locationNeighbors = new ArrayList<>(itemSize); + Float2FloatKeyValue location = itemLocations[itemIndex]; + for (int neighborIndex = 0; neighborIndex < itemSize; neighborIndex++) { + if (itemIndex != neighborIndex) { + Float2FloatKeyValue neighborLocation = itemLocations[neighborIndex]; + float distance = getDistance(location.getKey(), location.getValue(), neighborLocation.getKey(), neighborLocation.getValue()); + locationNeighbors.add(new KeyValue<>(neighborIndex, distance)); + } + } + Collections.sort(locationNeighbors, (left, right) -> { + // ascending order + return left.getValue().compareTo(right.getValue()); + }); + locationNeighbors = locationNeighbors.subList(0, k_nearest); + + for (int index = 0; index < locationNeighbors.size(); index++) { + int neighborItemIdx = locationNeighbors.get(index).getKey(); + float weight; + if (locationNeighbors.get(index).getValue() < 0.5F) { + weight = 1F / 0.5F; + } else { + weight = 1F / (locationNeighbors.get(index).getValue()); + } + dataTable.setValue(itemIndex, neighborItemIdx, weight); + } + } + + SparseMatrix matrix = SparseMatrix.valueOf(itemSize, itemSize, dataTable); + neighborWeights = new ArrayVector[itemSize]; + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + ArrayVector neighborVector = new ArrayVector(matrix.getRowVector(itemIndex)); + neighborVector.scaleValues(1F / neighborVector.getSum(false)); + neighborWeights[itemIndex] = neighborVector; + } + } + + private void calculateGeoInfluenceMatrix() throws ModelException { + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + ArrayVector neighborVector = neighborWeights[itemIndex]; + if (neighborVector.getElementSize() == 0) { + continue; + } + DenseVector geoVector = geoInfluences.getRowVector(itemIndex); + geoVector.setValues(0F); + for (VectorScalar term : neighborVector) { + DenseVector itemVector = itemFactors.getRowVector(term.getIndex()); + geoVector.iterateElement(MathCalculator.SERIAL, (scalar) -> { + int index = scalar.getIndex(); + float value = scalar.getValue(); + 
scalar.setValue(value + itemVector.getValue(index) * term.getValue()); + }); + } + } + } + + private float getDistance(float leftLatitude, float leftLongitude, float rightLatitude, float rightLongitude) { + // haversine great-circle distance; Earth radius in metres, result in kilometres + float earthRadius = 6378137F; + leftLatitude = (float) (leftLatitude * Math.PI / 180F); + rightLatitude = (float) (rightLatitude * Math.PI / 180F); + float latitude = leftLatitude - rightLatitude; + float longitude = (float) ((leftLongitude - rightLongitude) * Math.PI / 180F); + latitude = (float) Math.sin(latitude / 2F); + longitude = (float) Math.sin(longitude / 2F); + float distance = (float) (2F * earthRadius * Math.asin(Math.sqrt(latitude * latitude + Math.cos(leftLatitude) * Math.cos(rightLatitude) * longitude * longitude))); + return distance / 1000F; + } + + private boolean indicator(double left, double right) { + return left > right; + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + float value = scalar.dotProduct(explicitUserFactors.getRowVector(userIndex), itemFactors.getRowVector(itemIndex)).getValue(); + value += scalar.dotProduct(implicitUserFactors.getRowVector(userIndex), geoInfluences.getRowVector(itemIndex)).getValue(); + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/ranking/SBPRModel.java b/src/main/java/com/jstarcraft/rns/model/context/ranking/SBPRModel.java new file mode 100644 index 0000000..b918ae5 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/ranking/SBPRModel.java @@ -0,0 +1,216 @@ +package com.jstarcraft.rns.model.context.ranking; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.SocialModel; +import com.jstarcraft.rns.utility.LogisticUtility; + +import it.unimi.dsi.fastutil.ints.IntSet; + +/** + * + * SBPR recommender + * + * <pre>
+ * Social Bayesian Personalized Ranking (SBPR)
+ * Leveraging Social Connections to Improve Personalized Ranking for Collaborative Filtering
+ * Based on the LibRec team's implementation
+ * </pre>
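+ * + * Sampling as implemented below: for each draw, a rated item i, a "social" item k rated only by the user's trusted neighbours, and an item j rated by neither are chosen; the model optimizes score(u, i) > score(u, k) > score(u, j), damping the (i, k) error by 1 / (1 + socialWeight), where socialWeight counts the neighbours who rated k.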
+ * + * @author Birdy + * + */ +// TODO still needs refactoring +public class SBPRModel extends SocialModel { + /** + * items biases vector + */ + private DenseVector itemBiases; + + /** + * bias regularization + */ + protected float regBias; + + /** + * find items rated by trusted neighbors only + */ + // TODO consider refactoring to List<IntSet> + private List<List<Integer>> socialItemList; + + private List<IntSet> userItemSet; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + regBias = configuration.getFloat("recommender.bias.regularization", 0.01F); + // cacheSpec = conf.get("guava.cache.spec", + // "maximumSize=5000,expireAfterAccess=50m"); + + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + + userItemSet = getUserItemSet(scoreMatrix); + + // TODO consider refactoring + // find items rated by trusted neighbors only + socialItemList = new ArrayList<>(userSize); + + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + IntSet itemSet = userItemSet.get(userIndex); + // find items rated by trusted neighbors only + + SparseVector socialVector = socialMatrix.getRowVector(userIndex); + List<Integer> socialList = new LinkedList<>(); + for (VectorScalar term : socialVector) { + int socialIndex = term.getIndex(); + userVector = scoreMatrix.getRowVector(socialIndex); + for (VectorScalar entry : userVector) { + int itemIndex = entry.getIndex(); + // v's rated items + if (!itemSet.contains(itemIndex) && !socialList.contains(itemIndex)) { + socialList.add(itemIndex); + } + } + } + socialItemList.add(new ArrayList<>(socialList)); + } + } + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + for (int sampleIndex = 0, sampleTimes = userSize * 100; sampleIndex < sampleTimes; sampleIndex++) { + // uniformly draw (userIdx, posItemIdx, k, negItemIdx) + int userIndex, positiveItemIndex, negativeItemIndex; + // userIdx + SparseVector userVector; + do { + userIndex = RandomUtility.randomInteger(userSize); + userVector = scoreMatrix.getRowVector(userIndex); + } while (userVector.getElementSize() == 0); + + // positive item index + positiveItemIndex = userVector.getIndex(RandomUtility.randomInteger(userVector.getElementSize())); + float positiveScore = predict(userIndex, positiveItemIndex); + + // social Items List + // TODO an IntSet would be more appropriate here. 
+ List<Integer> socialList = socialItemList.get(userIndex); + IntSet itemSet = userItemSet.get(userIndex); + do { + negativeItemIndex = RandomUtility.randomInteger(itemSize); + } while (itemSet.contains(negativeItemIndex) || socialList.contains(negativeItemIndex)); + float negativeScore = predict(userIndex, negativeItemIndex); + + if (socialList.size() > 0) { + // if having social neighbors + int itemIndex = socialList.get(RandomUtility.randomInteger(socialList.size())); + float socialScore = predict(userIndex, itemIndex); + SparseVector socialVector = socialMatrix.getRowVector(userIndex); + float socialWeight = 0F; + for (VectorScalar term : socialVector) { + int socialIndex = term.getIndex(); + itemSet = userItemSet.get(socialIndex); + if (itemSet.contains(itemIndex)) { + socialWeight += 1; + } + } + float positiveError = (positiveScore - socialScore) / (1 + socialWeight); + float negativeError = socialScore - negativeScore; + float positiveGradient = LogisticUtility.getValue(-positiveError), negativeGradient = LogisticUtility.getValue(-negativeError); + float error = (float) (-Math.log(1 - positiveGradient) - Math.log(1 - negativeGradient)); + totalError += error; + + // update bi, bk, bj + float positiveBias = itemBiases.getValue(positiveItemIndex); + itemBiases.shiftValue(positiveItemIndex, learnRatio * (positiveGradient / (1F + socialWeight) - regBias * positiveBias)); + totalError += regBias * positiveBias * positiveBias; + float socialBias = itemBiases.getValue(itemIndex); + itemBiases.shiftValue(itemIndex, learnRatio * (-positiveGradient / (1F + socialWeight) + negativeGradient - regBias * socialBias)); + totalError += regBias * socialBias * socialBias; + float negativeBias = itemBiases.getValue(negativeItemIndex); + itemBiases.shiftValue(negativeItemIndex, learnRatio * (-negativeGradient - regBias * negativeBias)); + totalError += regBias * negativeBias * negativeBias; + + // update P, Q + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex); + float itemFactor = itemFactors.getValue(itemIndex, factorIndex); + float negativeFactor = itemFactors.getValue(negativeItemIndex, factorIndex); + float delta = positiveGradient * (positiveFactor - itemFactor) / (1F + socialWeight) + negativeGradient * (itemFactor - negativeFactor); + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (delta - userRegularization * userFactor)); + itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (positiveGradient * userFactor / (1F + socialWeight) - itemRegularization * positiveFactor)); + itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (negativeGradient * (-userFactor) - itemRegularization * negativeFactor)); + delta = positiveGradient * (-userFactor / (1F + socialWeight)) + negativeGradient * userFactor; + itemFactors.shiftValue(itemIndex, factorIndex, learnRatio * (delta - itemRegularization * itemFactor)); + totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveFactor * positiveFactor + itemRegularization * negativeFactor * negativeFactor + itemRegularization * itemFactor * itemFactor; + } + } else { + // if no social neighbors, the same as BPR + float error = positiveScore - negativeScore; + totalError += error; + float gradient = LogisticUtility.getValue(-error); + + // update bi, bj + float positiveBias = itemBiases.getValue(positiveItemIndex); + 
itemBiases.shiftValue(positiveItemIndex, learnRatio * (gradient - regBias * positiveBias)); + totalError += regBias * positiveBias * positiveBias; + float negativeBias = itemBiases.getValue(negativeItemIndex); + itemBiases.shiftValue(negativeItemIndex, learnRatio * (-gradient - regBias * negativeBias)); + totalError += regBias * negativeBias * negativeBias; + + // update user factors, item factors + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float positiveFactor = itemFactors.getValue(positiveItemIndex, factorIndex); + float negItemFactorValue = itemFactors.getValue(negativeItemIndex, factorIndex); + userFactors.shiftValue(userIndex, factorIndex, learnRatio * (gradient * (positiveFactor - negItemFactorValue) - userRegularization * userFactor)); + itemFactors.shiftValue(positiveItemIndex, factorIndex, learnRatio * (gradient * userFactor - itemRegularization * positiveFactor)); + itemFactors.shiftValue(negativeItemIndex, factorIndex, learnRatio * (gradient * (-userFactor) - itemRegularization * negItemFactorValue)); + totalError += userRegularization * userFactor * userFactor + itemRegularization * positiveFactor * positiveFactor + itemRegularization * negItemFactorValue * negItemFactorValue; + } + } + } + + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector userVector = userFactors.getRowVector(userIndex); + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + return itemBiases.getValue(itemIndex) + scalar.dotProduct(userVector, itemVector).getValue(); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(predict(userIndex, itemIndex)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/rating/RSTEModel.java b/src/main/java/com/jstarcraft/rns/model/context/rating/RSTEModel.java new file mode 100644 index 0000000..b07c18a --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/rating/RSTEModel.java @@ -0,0 +1,169 @@ +package com.jstarcraft.rns.model.context.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.SocialModel; +import com.jstarcraft.rns.utility.LogisticUtility; + +/** + * + * RSTE recommender + * + * <pre>
+ * Learning to Recommend with Social Trust Ensemble
+ * Based on the LibRec team's implementation
+ * </pre>
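+ * + * Prediction as implemented below: r(u, i) = g(alpha * dot(p(u), q(i)) + (1 - alpha) * sum(t(u, v) * dot(p(v), q(i))) / sum(t(u, v))), where alpha is "recommender.user.social.ratio" (default 0.8F) and g is the logistic function.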
+ * + * @author Birdy + * + */ +public class RSTEModel extends SocialModel { + private float userSocialRatio; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + userSocialRatio = configuration.getFloat("recommender.user.social.ratio", 0.8F); + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector socialFactors = DenseVector.valueOf(factorSize); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + DenseMatrix userDeltas = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix itemDeltas = DenseMatrix.valueOf(itemSize, factorSize); + + // ratings + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector socialVector = socialMatrix.getRowVector(userIndex); + float socialWeight = 0F; + socialFactors.setValues(0F); + for (VectorScalar socialTerm : socialVector) { + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + socialFactors.setValue(factorIndex, socialFactors.getValue(factorIndex) + socialTerm.getValue() * userFactors.getValue(socialTerm.getIndex(), factorIndex)); + } + socialWeight += socialTerm.getValue(); + } + DenseVector userVector = userFactors.getRowVector(userIndex); + for (VectorScalar rateTerm : scoreMatrix.getRowVector(userIndex)) { + int itemIndex = rateTerm.getIndex(); + float score = rateTerm.getValue(); + score = (score - minimumScore) / (maximumScore - minimumScore); + // compute directly to speed up calculation + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + float predict = scalar.dotProduct(userVector, itemVector).getValue(); + float sum = 0F; + for (VectorScalar socialTerm : socialVector) { + sum += socialTerm.getValue() * scalar.dotProduct(userFactors.getRowVector(socialTerm.getIndex()), itemVector).getValue(); + } + predict = userSocialRatio * predict + (1F - userSocialRatio) * (socialWeight > 0F ? sum / socialWeight : 0F); + float error = LogisticUtility.getValue(predict) - score; + totalError += error * error; + error = LogisticUtility.getGradient(predict) * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float itemFactor = itemFactors.getValue(itemIndex, factorIndex); + float userDelta = userSocialRatio * error * itemFactor + userRegularization * userFactor; + float socialFactor = socialWeight > 0 ? 
socialFactors.getValue(factorIndex) / socialWeight : 0; + float itemDelta = error * (userSocialRatio * userFactor + (1 - userSocialRatio) * socialFactor) + itemRegularization * itemFactor; + userDeltas.shiftValue(userIndex, factorIndex, userDelta); + itemDeltas.shiftValue(itemIndex, factorIndex, itemDelta); + totalError += userRegularization * userFactor * userFactor + itemRegularization * itemFactor * itemFactor; + } + } + } + + // social + for (int trusterIndex = 0; trusterIndex < userSize; trusterIndex++) { + SparseVector trusterVector = socialMatrix.getColumnVector(trusterIndex); + for (VectorScalar term : trusterVector) { + int trusteeIndex = term.getIndex(); + SparseVector trusteeVector = socialMatrix.getRowVector(trusteeIndex); + DenseVector userVector = userFactors.getRowVector(trusteeIndex); + float socialWeight = 0F; + for (VectorScalar socialTerm : trusteeVector) { + socialWeight += socialTerm.getValue(); + } + for (VectorScalar rateTerm : scoreMatrix.getRowVector(trusteeIndex)) { + int itemIndex = rateTerm.getIndex(); + float score = rateTerm.getValue(); + score = (score - minimumScore) / (maximumScore - minimumScore); + // compute prediction for user-item (p, j) + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + float predict = scalar.dotProduct(userVector, itemVector).getValue(); + float sum = 0F; + for (VectorScalar socialTerm : trusteeVector) { + sum += socialTerm.getValue() * scalar.dotProduct(itemFactors.getRowVector(socialTerm.getIndex()), itemVector).getValue(); + } + predict = userSocialRatio * predict + (1F - userSocialRatio) * (socialWeight > 0F ? sum / socialWeight : 0F); + // double pred = predict(p, j, false); + float error = LogisticUtility.getValue(predict) - score; + error = LogisticUtility.getGradient(predict) * error * term.getValue(); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + userDeltas.shiftValue(trusterIndex, factorIndex, (1 - userSocialRatio) * error * itemFactors.getValue(itemIndex, factorIndex)); + } + } + } + } + userFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + userDeltas.getValue(row, column) * -learnRatio); + }); + itemFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + itemDeltas.getValue(row, column) * -learnRatio); + }); + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector userVector = userFactors.getRowVector(userIndex); + DenseVector itemVector = itemFactors.getRowVector(itemIndex); + float predict = scalar.dotProduct(userVector, itemVector).getValue(); + float sum = 0F, socialWeight = 0F; + SparseVector socialVector = socialMatrix.getRowVector(userIndex); + for (VectorScalar socialTerm : socialVector) { + float score = socialTerm.getValue(); + DenseVector socialFactor = userFactors.getRowVector(socialTerm.getIndex()); + sum += score * scalar.dotProduct(socialFactor, itemVector).getValue(); + socialWeight += score; + } + predict = userSocialRatio * predict + (1 - 
userSocialRatio) * (socialWeight > 0 ? sum / socialWeight : 0); + instance.setQuantityMark(denormalize(LogisticUtility.getValue(predict))); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/rating/SoRecModel.java b/src/main/java/com/jstarcraft/rns/model/context/rating/SoRecModel.java new file mode 100644 index 0000000..e23c4a8 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/rating/SoRecModel.java @@ -0,0 +1,160 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.ArrayList; +import java.util.List; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.SocialModel; +import com.jstarcraft.rns.utility.LogisticUtility; + +/** + * + * SoRec recommender + * + * <pre>
+ * SoRec: Social recommendation using probabilistic matrix factorization
+ * Based on the LibRec team's implementation
+ * </pre>
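+ * + * As implemented below, the score matrix and the social matrix are factorized jointly with a shared user factor matrix; each social link is fitted to its value scaled by sqrt(inDegree(v) / (outDegree(u) + inDegree(v))) and weighted by "recommender.rate.social.regularization".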
+ * + * @author Birdy + * + */ +public class SoRecModel extends SocialModel { + /** + * social factor matrix + */ + private DenseMatrix socialFactors; + + private float regScore, regSocial; + + private List<Integer> inDegrees, outDegrees; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + socialFactors = DenseMatrix.valueOf(userSize, factorSize); + socialFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + + regScore = configuration.getFloat("recommender.rate.social.regularization", 0.01F); + regSocial = configuration.getFloat("recommender.user.social.regularization", 0.01F); + + inDegrees = new ArrayList<>(userSize); + outDegrees = new ArrayList<>(userSize); + + for (int userIndex = 0; userIndex < userSize; userIndex++) { + int in = socialMatrix.getColumnScope(userIndex); + int out = socialMatrix.getRowScope(userIndex); + inDegrees.add(in); + outDegrees.add(out); + } + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + DenseMatrix userDeltas = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix itemDeltas = DenseMatrix.valueOf(itemSize, factorSize); + DenseMatrix socialDeltas = DenseMatrix.valueOf(userSize, factorSize); + + // ratings + for (MatrixScalar term : scoreMatrix) { + int userIdx = term.getRow(); + int itemIdx = term.getColumn(); + float score = term.getValue(); + float predict = super.predict(userIdx, itemIdx); + float error = LogisticUtility.getValue(predict) - (score - minimumScore) / (maximumScore - minimumScore); + totalError += error * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIdx, factorIndex); + float itemFactor = itemFactors.getValue(itemIdx, factorIndex); + userDeltas.shiftValue(userIdx, factorIndex, LogisticUtility.getGradient(predict) * error * itemFactor + userRegularization * userFactor); + itemDeltas.shiftValue(itemIdx, factorIndex, LogisticUtility.getGradient(predict) * error * userFactor + itemRegularization * itemFactor); + totalError += userRegularization * userFactor * userFactor + itemRegularization * itemFactor * itemFactor; + } + } + + // friends + // TODO the social matrix is symmetric here; is there a way to reduce the computation? 
+ for (MatrixScalar term : socialMatrix) { + int userIndex = term.getRow(); + int socialIndex = term.getColumn(); + float socialScore = term.getValue(); + // tuv ~ cik in the original paper + if (socialScore == 0F) { + continue; + } + float socialPredict = scalar.dotProduct(userFactors.getRowVector(userIndex), socialFactors.getRowVector(socialIndex)).getValue(); + float socialInDegree = inDegrees.get(socialIndex); // ~ d-(k) + float userOutDegree = outDegrees.get(userIndex); // ~ d+(i) + float weight = (float) Math.sqrt(socialInDegree / (userOutDegree + socialInDegree)); + float socialError = LogisticUtility.getValue(socialPredict) - weight * socialScore; + totalError += regScore * socialError * socialError; + + socialPredict = LogisticUtility.getGradient(socialPredict); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float socialFactor = socialFactors.getValue(socialIndex, factorIndex); + userDeltas.shiftValue(userIndex, factorIndex, regScore * socialPredict * socialError * socialFactor); + socialDeltas.shiftValue(socialIndex, factorIndex, regScore * socialPredict * socialError * userFactor + regSocial * socialFactor); + totalError += regSocial * socialFactor * socialFactor; + } + } + + userFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + userDeltas.getValue(row, column) * -learnRatio); + }); + itemFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + itemDeltas.getValue(row, column) * -learnRatio); + }); + socialFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + socialDeltas.getValue(row, column) * -learnRatio); + }); + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float predict = super.predict(userIndex, itemIndex); + predict = denormalize(LogisticUtility.getValue(predict)); + instance.setQuantityMark(predict); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/rating/SoRegModel.java b/src/main/java/com/jstarcraft/rns/model/context/rating/SoRegModel.java new file mode 100644 index 0000000..2c88120 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/rating/SoRegModel.java @@ -0,0 +1,152 @@ +package com.jstarcraft.rns.model.context.rating; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.algorithm.correlation.MathCorrelation; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SymmetryMatrix; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.common.reflection.ReflectionUtility; 
+import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.SocialModel; + +/** + * + * SoReg recommender + * + * <pre>
+ * Recommender systems with social regularization
+ * Based on the LibRec team's implementation
+ * </pre>
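+ * + * As implemented below, ratings are fitted by plain matrix factorization while a social regularizer pulls each user factor towards the factors of both trusters and trustees, scaled by a correlation coefficient mapped into [0, 1] via (1 + similarity) / 2.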
+ * + * @author Birdy + * + */ +public class SoRegModel extends SocialModel { + + private SymmetryMatrix socialCorrelations; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + + // TODO change this to a configuration enum + try { + Class<MathCorrelation> correlationClass = (Class<MathCorrelation>) Class.forName(configuration.getString("recommender.correlation.class")); + MathCorrelation correlation = ReflectionUtility.getInstance(correlationClass); + socialCorrelations = new SymmetryMatrix(socialMatrix.getRowSize()); + correlation.calculateCoefficients(socialMatrix, false, socialCorrelations::setValue); + } catch (Exception exception) { + throw new RuntimeException(exception); + } + + for (MatrixScalar term : socialCorrelations) { + float similarity = term.getValue(); + if (similarity == 0F) { + continue; + } + similarity = (1F + similarity) / 2F; + term.setValue(similarity); + } + } + + @Override + protected void doPractice() { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + DenseMatrix userDeltas = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix itemDeltas = DenseMatrix.valueOf(itemSize, factorSize); + + // ratings + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float error = predict(userIndex, itemIndex) - term.getValue(); + totalError += error * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactorValue = userFactors.getValue(userIndex, factorIndex); + float itemFactorValue = itemFactors.getValue(itemIndex, factorIndex); + userDeltas.shiftValue(userIndex, factorIndex, error * itemFactorValue + userRegularization * userFactorValue); + itemDeltas.shiftValue(itemIndex, factorIndex, error * userFactorValue + itemRegularization * itemFactorValue); + totalError += userRegularization * userFactorValue * userFactorValue + itemRegularization * itemFactorValue * itemFactorValue; + } + } + + // friends + for (int userIndex = 0; userIndex < userSize; userIndex++) { + // out links: F+ + SparseVector trusterVector = socialMatrix.getRowVector(userIndex); + for (VectorScalar term : trusterVector) { + int trusterIndex = term.getIndex(); + float trusterSimilarity = socialCorrelations.getValue(userIndex, trusterIndex); + if (!Float.isNaN(trusterSimilarity)) { + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex) - userFactors.getValue(trusterIndex, factorIndex); + userDeltas.shiftValue(userIndex, factorIndex, socialRegularization * trusterSimilarity * userFactor); + totalError += socialRegularization * trusterSimilarity * userFactor * userFactor; + } + } + } + + // in links: F- + SparseVector trusteeVector = socialMatrix.getColumnVector(userIndex); + for (VectorScalar term : trusteeVector) { + int trusteeIndex = term.getIndex(); + float trusteeSimilarity = socialCorrelations.getValue(userIndex, trusteeIndex); + if (!Float.isNaN(trusteeSimilarity)) { + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, 
factorIndex) - userFactors.getValue(trusteeIndex, factorIndex); + userDeltas.shiftValue(userIndex, factorIndex, socialRegularization * trusteeSimilarity * userFactor); + totalError += socialRegularization * trusteeSimilarity * userFactor * userFactor; + } + } + } + } + + // end of for loop + userFactors.iterateElement(MathCalculator.PARALLEL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue(value + userDeltas.getValue(row, column) * -learnRatio); + }); + itemFactors.iterateElement(MathCalculator.PARALLEL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue(value + itemDeltas.getValue(row, column) * -learnRatio); + }); + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + @Override + protected float predict(int userIndex, int itemIndex) { + float score = super.predict(userIndex, itemIndex); + if (score > maximumScore) { + score = maximumScore; + } else if (score < minimumScore) { + score = minimumScore; + } + return score; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/rating/SocialMFModel.java b/src/main/java/com/jstarcraft/rns/model/context/rating/SocialMFModel.java new file mode 100644 index 0000000..5182db7 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/rating/SocialMFModel.java @@ -0,0 +1,144 @@ +package com.jstarcraft.rns.model.context.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.SocialModel; +import com.jstarcraft.rns.utility.LogisticUtility; + +/** + * + * SocialMF recommender + * + * <pre>
+ * A matrix factorization technique with trust propagation for recommendation in social networks
+ * Based on the LibRec team's implementation<br>
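+ *
+ * Objective sketch (paper notation, not fields of this class; a rough guide to what
+ * doPractice() minimizes, assuming the formulation of the cited paper):<br>
+ * L = 1/2 * sum_(u,i) (g(p_u q_i^T) - r_ui)^2<br>
+ *   + lambdaT/2 * sum_u ||p_u - sum_(v in N_u) T_uv p_v||^2<br>
+ *   + lambdaU/2 * ||P||^2 + lambdaV/2 * ||Q||^2<br>
+ * where g is the logistic function and T_uv the normalized trust weights.<br>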
+ * 
+ * + * @author Birdy + * + */ +public class SocialMFModel extends SocialModel { + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + userFactors = DenseMatrix.valueOf(userSize, factorSize); + userFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + } + + // TODO 需要重构 + @Override + protected void doPractice() { + DenseVector socialFactors = DenseVector.valueOf(factorSize); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + DenseMatrix userDeltas = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix itemDeltas = DenseMatrix.valueOf(itemSize, factorSize); + + // rated items + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + float predict = super.predict(userIndex, itemIndex); + float error = LogisticUtility.getValue(predict) - normalize(score); + totalError += error * error; + error = LogisticUtility.getGradient(predict) * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float itemFactor = itemFactors.getValue(itemIndex, factorIndex); + userDeltas.shiftValue(userIndex, factorIndex, error * itemFactor + userRegularization * userFactor); + itemDeltas.shiftValue(itemIndex, factorIndex, error * userFactor + itemRegularization * itemFactor); + totalError += userRegularization * userFactor * userFactor + itemRegularization * itemFactor * itemFactor; + } + } + + // social regularization + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector trusterVector = socialMatrix.getRowVector(userIndex); + int numTrusters = trusterVector.getElementSize(); + if (numTrusters == 0) { + continue; + } + socialFactors.setValues(0F); + for (VectorScalar trusterTerm : trusterVector) { + int trusterIndex = trusterTerm.getIndex(); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + socialFactors.setValue(factorIndex, socialFactors.getValue(factorIndex) + trusterTerm.getValue() * userFactors.getValue(trusterIndex, factorIndex)); + } + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float error = userFactors.getValue(userIndex, factorIndex) - socialFactors.getValue(factorIndex) / numTrusters; + userDeltas.shiftValue(userIndex, factorIndex, socialRegularization * error); + totalError += socialRegularization * error * error; + } + + // those who trusted user u + SparseVector trusteeVector = socialMatrix.getColumnVector(userIndex); + int numTrustees = trusteeVector.getElementSize(); + for (VectorScalar trusteeTerm : trusteeVector) { + int trusteeIndex = trusteeTerm.getIndex(); + trusterVector = socialMatrix.getRowVector(trusteeIndex); + numTrusters = trusterVector.getElementSize(); + if (numTrusters == 0) { + continue; + } + socialFactors.setValues(0F); + for (VectorScalar trusterTerm : trusterVector) { + int trusterIndex = trusterTerm.getIndex(); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + socialFactors.setValue(factorIndex, socialFactors.getValue(factorIndex) + trusterTerm.getValue() * userFactors.getValue(trusterIndex, factorIndex)); + } + } + for (int 
factorIndex = 0; factorIndex < factorSize; factorIndex++) { + userDeltas.shiftValue(userIndex, factorIndex, -socialRegularization * (trusteeTerm.getValue() / numTrustees) * (userFactors.getValue(trusteeIndex, factorIndex) - socialFactors.getValue(factorIndex) / numTrusters)); + } + } + } + // update user factors + userFactors.iterateElement(MathCalculator.PARALLEL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue(value + userDeltas.getValue(row, column) * -learnRatio); + }); + itemFactors.iterateElement(MathCalculator.PARALLEL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue(value + itemDeltas.getValue(row, column) * -learnRatio); + }); + + totalError *= 0.5D; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float predict = super.predict(userIndex, itemIndex); + instance.setQuantityMark(denormalize(LogisticUtility.getValue(predict))); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/rating/TimeSVDModel.java b/src/main/java/com/jstarcraft/rns/model/context/rating/TimeSVDModel.java new file mode 100644 index 0000000..b182925 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/rating/TimeSVDModel.java @@ -0,0 +1,478 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Table; +import com.google.common.collect.Table.Cell; +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.collaborative.rating.BiasedMFModel; +import com.jstarcraft.rns.model.exception.ModelException; + +/** + * + * TimeSVD++推荐器 + * + *
+ * Collaborative Filtering with Temporal Dynamics
+ * Based on the LibRec team's implementation<br>
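+ *
+ * Prediction sketch (paper notation; a rough guide to the terms assembled in
+ * doPractice() and predict(), assuming eq. (9)-(13) of the cited paper):<br>
+ * r^(u,i,t) = mu + (b_i + b_i,Bin(t)) * (c_u + c_u,t) + b_u + alpha_u * dev_u(t) + b_u,t<br>
+ *   + q_i^T * (p_u + a_u * dev_u(t) + p_u,t + |R(u)|^(-1/2) * sum_(j in R(u)) y_j)<br>
+ * where dev_u(t) = sign(t - t_u) * |t - t_u|^beta (see deviation()) and a_u is the
+ * per-factor counterpart of alpha_u (userImplicitFactors).<br>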
+ * 
+ * + * @author Birdy + * + */ +public class TimeSVDModel extends BiasedMFModel { + + protected String instantField; + + protected int instantDimension; + + /** + * the span of days of rating timestamps + */ + private static int numDays; + + /** + * {user, mean date} + */ + private DenseVector userMeanDays; + + /** + * time decay factor(时间衰退因子) + */ + private float decay; + + /** + * number of bins over all the items + */ + private int numSections; + + /** + * {user, appender} alpha matrix + */ + private DenseMatrix userImplicitFactors; + + /** + * item's implicit influence + */ + private DenseMatrix itemImplicitFactors; + + private DenseMatrix userExplicitFactors; + + private DenseMatrix itemExplicitFactors; + + /** + * {item, bin(t)} bias matrix + */ + private DenseMatrix itemSectionBiases; + + /** + * {user, day, bias} table + */ + private Table userDayBiases; + + /** + * user bias weight parameters + */ + private DenseVector userBiasWeights; + + /** + * {user, {appender, day, value} } map + */ + private Table userDayFactors; + + /** + * {user, user scaling stable part} + */ + private DenseVector userScales; + + /** + * {user, day, day-specific scaling part} + */ + private DenseMatrix userDayScales; + + /** + * minimum, maximum timestamp + */ + private static int minTimestamp, maxTimestamp; + + /** + * matrix of time stamp + */ + // TODO 既包含trainTerm,又包含testTerm + private Table instantTabel; + + /* + * (non-Javadoc) + * + * @see net.librecommender.recommender.cf.rating.BiasedMFRecommender#setup() + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + decay = configuration.getFloat("recommender.learnrate.decay", 0.015F); + numSections = configuration.getInteger("recommender.numBins", 6); + + instantField = configuration.getString("data.model.fields.instant"); + instantDimension = model.getQualityInner(instantField); + + instantTabel = HashBasedTable.create(); + for (DataInstance sample : model) { + instantTabel.put(sample.getQualityFeature(userDimension), sample.getQualityFeature(itemDimension), sample.getQualityFeature(instantDimension)); + } + getMaxAndMinTimeStamp(); + numDays = days(maxTimestamp, minTimestamp) + 1; + // TODO 考虑重构 + userBiases = DenseVector.valueOf(userSize); + userBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + userBiasWeights = DenseVector.valueOf(userSize); + userBiasWeights.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemSectionBiases = DenseMatrix.valueOf(itemSize, numSections); + itemSectionBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemImplicitFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemImplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + userImplicitFactors = DenseMatrix.valueOf(userSize, factorSize); + userImplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + userDayBiases = HashBasedTable.create(); + userDayFactors = HashBasedTable.create(); + userScales = DenseVector.valueOf(userSize); + userScales.iterateElement(MathCalculator.SERIAL, 
(scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + userDayScales = DenseMatrix.valueOf(userSize, numDays); + userDayScales.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + userExplicitFactors = DenseMatrix.valueOf(userSize, factorSize); + userExplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + itemExplicitFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemExplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + // global average date + float mean; + float sum = 0F; + int count = 0; + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + sum += days(instantTabel.get(userIndex, itemIndex), minTimestamp); + count++; + } + float globalMeanDays = sum / count; + // compute user's mean of rating timestamps + userMeanDays = DenseVector.valueOf(userSize); + for (int userIndex = 0; userIndex < userSize; userIndex++) { + sum = 0F; + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + for (VectorScalar term : userVector) { + int itemIndex = term.getIndex(); + sum += days(instantTabel.get(userIndex, itemIndex), minTimestamp); + } + mean = (userVector.getElementSize() > 0) ? (sum + 0F) / userVector.getElementSize() : globalMeanDays; + userMeanDays.setValue(userIndex, mean); + } + } + + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector rateVector = scoreMatrix.getRowVector(userIndex); + int size = rateVector.getElementSize(); + if (size == 0) { + continue; + } + for (VectorScalar term : rateVector) { + int itemExplicitIndex = term.getIndex(); + float score = term.getValue(); + // TODO 此处可以重构. + int instant = instantTabel.get(userIndex, itemExplicitIndex); + // day t + int days = days(instant, minTimestamp); + int section = section(days); + float deviation = deviation(userIndex, days); + float userBias = userBiases.getValue(userIndex); + float itemBias = itemBiases.getValue(itemExplicitIndex); + + float userScale = userScales.getValue(userIndex); + float dayScale = userDayScales.getValue(userIndex, days); + // lazy initialization + // TODO 此处可以重构. + if (!userDayBiases.contains(userIndex, days)) { + userDayBiases.put(userIndex, days, RandomUtility.randomFloat(1F)); + } + float userDayBias = userDayBiases.get(userIndex, days); + float itemSectionBias = itemSectionBiases.getValue(itemExplicitIndex, section); + // alpha_u + float userWeight = userBiasWeights.getValue(userIndex); + // mu bi(t) + float predict = meanScore + (itemBias + itemSectionBias) * (userScale + dayScale); + // bu(t) + predict += userBias + userWeight * deviation + userDayBias; + // qi * yj + DenseVector itemExplicitVector = itemExplicitFactors.getRowVector(itemExplicitIndex); + float sum = 0F; + for (VectorScalar rateTerm : rateVector) { + int itemImplicitIndex = rateTerm.getIndex(); + DenseVector itemImpilcitVector = itemImplicitFactors.getRowVector(itemImplicitIndex); + sum += scalar.dotProduct(itemImpilcitVector, itemExplicitVector).getValue(); + } + float itemWeight = (float) (size > 0 ? 
Math.pow(size, -0.5F) : 0F); + predict += sum * itemWeight; + // qi * pu(t) + float[] dayFactors = userDayFactors.get(userIndex, days); + if (dayFactors == null) { + dayFactors = new float[factorSize]; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + dayFactors[factorIndex] = RandomUtility.randomFloat(1F); + } + userDayFactors.put(userIndex, days, dayFactors); + } + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float qik = itemExplicitFactors.getValue(itemExplicitIndex, factorIndex); + float puk = userExplicitFactors.getValue(userIndex, factorIndex) + userImplicitFactors.getValue(userIndex, factorIndex) * deviation + dayFactors[factorIndex]; + predict += puk * qik; + } + + float error = predict - score; + totalError += error * error; + + // update bi + float sgd = error * (userScale + dayScale) + regBias * itemBias; + itemBiases.shiftValue(itemExplicitIndex, -learnRatio * sgd); + totalError += regBias * itemBias * itemBias; + + // update bi,bin(t) + sgd = error * (userScale + dayScale) + regBias * itemSectionBias; + itemSectionBiases.shiftValue(itemExplicitIndex, section, -learnRatio * sgd); + totalError += regBias * itemSectionBias * itemSectionBias; + + // update cu + sgd = error * (itemBias + itemSectionBias) + regBias * userScale; + userScales.shiftValue(userIndex, -learnRatio * sgd); + totalError += regBias * userScale * userScale; + + // update cut + sgd = error * (itemBias + itemSectionBias) + regBias * dayScale; + userDayScales.shiftValue(userIndex, days, -learnRatio * sgd); + totalError += regBias * dayScale * dayScale; + + // update bu + sgd = error + regBias * userBias; + userBiases.shiftValue(userIndex, -learnRatio * sgd); + totalError += regBias * userBias * userBias; + + // update au + sgd = error * deviation + regBias * userWeight; + userBiasWeights.shiftValue(userIndex, -learnRatio * sgd); + totalError += regBias * userWeight * userWeight; + + // update but + sgd = error + regBias * userDayBias; + float delta = userDayBias - learnRatio * sgd; + userDayBiases.put(userIndex, days, delta); + totalError += regBias * userDayBias * userDayBias; + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userExplicitFactor = userExplicitFactors.getValue(userIndex, factorIndex); + float itemExplicitFactor = itemExplicitFactors.getValue(itemExplicitIndex, factorIndex); + float userImplicitFactor = userImplicitFactors.getValue(userIndex, factorIndex); + delta = dayFactors[factorIndex]; + + // TODO 此处可以整合操作 + sum = 0F; + // update userExplicitFactor + sgd = error * itemExplicitFactor + userRegularization * userExplicitFactor; + userExplicitFactors.shiftValue(userIndex, factorIndex, -learnRatio * sgd); + totalError += userRegularization * userExplicitFactor * userExplicitFactor; + + // update itemExplicitFactors + for (VectorScalar rateTerm : rateVector) { + int itemImplicitIndex = rateTerm.getIndex(); + sum += itemImplicitFactors.getValue(itemImplicitIndex, factorIndex); + } + sgd = error * (userExplicitFactor + userImplicitFactor * deviation + delta + itemWeight * sum) + itemRegularization * itemExplicitFactor; + itemExplicitFactors.shiftValue(itemExplicitIndex, factorIndex, -learnRatio * sgd); + totalError += itemRegularization * itemExplicitFactor * itemExplicitFactor; + + // update userImplicitFactors + sgd = error * itemExplicitFactor * deviation + userRegularization * userImplicitFactor; + userImplicitFactors.shiftValue(userIndex, factorIndex, -learnRatio * sgd); + totalError += userRegularization * 
userImplicitFactor * userImplicitFactor; + + // update itemImplicitFactors + // TODO 此处可以整合操作 + for (VectorScalar rateTerm : rateVector) { + int itemImplicitIndex = rateTerm.getIndex(); + float itemImplicitFactor = itemImplicitFactors.getValue(itemImplicitIndex, factorIndex); + sgd = error * itemWeight * itemExplicitFactor + itemRegularization * itemImplicitFactor; + itemImplicitFactors.shiftValue(itemImplicitIndex, factorIndex, -learnRatio * sgd); + totalError += itemRegularization * itemImplicitFactor * itemImplicitFactor; + } + + // update pkt + sgd = error * itemExplicitFactor + userRegularization * delta; + totalError += userRegularization * delta * delta; + delta = delta - learnRatio * sgd; + dayFactors[factorIndex] = delta; + } + + } + } + + totalError *= 0.5D; + if (isConverged(epocheIndex)) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + + /** + * predict a specific rating for user userIdx on item itemIdx. + * + * @param userIndex user index + * @param itemIndex item index + * @return predictive rating for user userIdx on item itemIdx + * @throws ModelException if error occurs + */ + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + // retrieve the test rating timestamp + int instant = instance.getQualityFeature(instantDimension); + int days = days(instant, minTimestamp); + int section = section(days); + float deviation = deviation(userIndex, days); + float value = meanScore; + + // bi(t): eq. (12) + value += (itemBiases.getValue(itemIndex) + itemSectionBiases.getValue(itemIndex, section)) * (userScales.getValue(userIndex) + userDayScales.getValue(userIndex, days)); + // bu(t): eq. (9) + value += (userBiases.getValue(userIndex) + userBiasWeights.getValue(userIndex) * deviation + (userDayBiases.contains(userIndex, days) ? userDayBiases.get(userIndex, days) : 0D)); + + // qi * yj + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + + float sum = 0F; + DenseVector itemExplicitVector = itemExplicitFactors.getRowVector(itemIndex); + for (VectorScalar term : userVector) { + DenseVector itemImplicitVector = itemImplicitFactors.getRowVector(term.getIndex()); + sum += scalar.dotProduct(itemImplicitVector, itemExplicitVector).getValue(); + } + float weight = (float) (userVector.getElementSize() > 0 ? Math.pow(userVector.getElementSize(), -0.5F) : 0F); + value += sum * weight; + + // qi * pu(t) + float[] dayFactors = userDayFactors.get(userIndex, days); + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float itemExplicitFactor = itemExplicitFactors.getValue(itemIndex, factorIndex); + // eq. (13) + float userExplicitFactor = userExplicitFactors.getValue(userIndex, factorIndex) + userImplicitFactors.getValue(userIndex, factorIndex) * deviation; + userExplicitFactor += (dayFactors == null ? 
0F : dayFactors[factorIndex]); + value += userExplicitFactor * itemExplicitFactor; + } + instance.setQuantityMark(value); + } + + /** + * get the time deviation for a specific timestamp + * + * @param userIndex the inner id of a user + * @param days the time stamp + * @return the time deviation for a specific timestamp t w.r.t the mean date tu + */ + private float deviation(int userIndex, int days) { + float mean = userMeanDays.getValue(userIndex); + // date difference in days + float deviation = days - mean; + return (float) (Math.signum(deviation) * Math.pow(Math.abs(deviation), decay)); + } + + /** + * get the bin number for a specific time stamp + * + * @param days time stamp of a day + * @return the bin number (starting from 0..numBins-1) for a specific timestamp + * t; + */ + // 将时间戳分段 + private int section(int days) { + return (int) (days / (numDays + 0D) * numSections); + } + + /** + * get the number of days for a given time difference + * + * @param duration the difference between two time stamps + * @return number of days for a given time difference + */ + private static int days(long duration) { + return (int) TimeUnit.MILLISECONDS.toDays(duration); + } + + /** + * get the number of days between two timestamps + * + * @param t1 time stamp 1 + * @param t2 time stamp 2 + * @return number of days between two timestamps + */ + private static int days(int t1, int t2) { + return days(Math.abs(t1 - t2)); + } + + /** + * get the maximum and minimum time stamps in the time matrix + * + */ + private void getMaxAndMinTimeStamp() { + minTimestamp = Integer.MAX_VALUE; + maxTimestamp = Integer.MIN_VALUE; + + for (Cell cell : instantTabel.cellSet()) { + int timeStamp = cell.getValue(); + if (timeStamp < minTimestamp) { + minTimestamp = timeStamp; + } + + if (timeStamp > maxTimestamp) { + maxTimestamp = timeStamp; + } + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/rating/TrustMFModel.java b/src/main/java/com/jstarcraft/rns/model/context/rating/TrustMFModel.java new file mode 100644 index 0000000..ded217c --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/rating/TrustMFModel.java @@ -0,0 +1,332 @@ +package com.jstarcraft.rns.model.context.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.SocialModel; +import com.jstarcraft.rns.model.exception.ModelException; +import com.jstarcraft.rns.utility.LogisticUtility; + +/** + * + * TrustMF推荐器 + * + *
+ * Social Collaborative Filtering by Trust
+ * Based on the LibRec team's implementation<br>
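+ *
+ * Model sketch (paper notation, assuming the cited formulation): the truster model
+ * factorizes ratings and trust jointly,<br>
+ * r_ui ~ g(b_u v_i^T), t_uk ~ g(b_u w_k^T)<br>
+ * and the trustee model mirrors it from the trustee side; in mode "T" both are trained
+ * and prediction averages the two factor sets, which is the value /= 4F in predict().<br>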
+ * 
+ * + * @author Birdy + * + */ +public class TrustMFModel extends SocialModel { + /** + * truster model + */ + private DenseMatrix trusterUserFactors, trusterItemFactors, trusteeUserDeltas; + + /** + * trustee model + */ + private DenseMatrix trusteeUserFactors, trusteeItemFactors, trusterUserDeltas; + + /** + * model selection identifier + */ + private String mode; + + // TODO 需要重构 + private void prepareByTruster() { + trusterUserFactors = DenseMatrix.valueOf(userSize, factorSize); + trusterUserFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + trusteeUserDeltas = DenseMatrix.valueOf(userSize, factorSize); + trusteeUserDeltas.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + trusterItemFactors = DenseMatrix.valueOf(itemSize, factorSize); + trusterItemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + } + + // TODO 需要重构 + private void prepareByTrustee() { + trusterUserDeltas = DenseMatrix.valueOf(userSize, factorSize); + trusterUserDeltas.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + trusteeUserFactors = DenseMatrix.valueOf(userSize, factorSize); + trusteeUserFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + trusteeItemFactors = DenseMatrix.valueOf(itemSize, factorSize); + trusteeItemFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(RandomUtility.randomFloat(1F)); + }); + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + mode = configuration.getString("recommender.social.model", "T"); + // algoName = "TrustMF (" + model + ")"; + switch (mode) { + case "Tr": + prepareByTruster(); + break; + case "Te": + prepareByTrustee(); + break; + case "T": + default: + prepareByTruster(); + prepareByTrustee(); + } + } + + /** + * Build TrusterMF model: Br*Vr + * + * @throws ModelException if error occurs + */ + private void trainByTruster(DefaultScalar scalar) { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + // gradients of trusterUserTrusterFactors, + // trusterUserTrusteeFactors, trusterItemFactors + DenseMatrix trusterGradients = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix trusteeGradients = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix itemGradients = DenseMatrix.valueOf(itemSize, factorSize); + + // rate matrix + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + float predict = predict(userIndex, itemIndex); + float error = LogisticUtility.getValue(predict) - normalize(score); + totalError += error * error; + error = LogisticUtility.getGradient(predict) * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float trusterUserFactor = trusterUserFactors.getValue(userIndex, factorIndex); + float trusterItemFactor = trusterItemFactors.getValue(itemIndex, factorIndex); + trusterGradients.shiftValue(userIndex, factorIndex, error * trusterItemFactor + userRegularization * trusterUserFactor); + itemGradients.shiftValue(itemIndex, factorIndex, error * trusterUserFactor + itemRegularization * trusterItemFactor); + totalError += userRegularization * trusterUserFactor * 
trusterUserFactor + itemRegularization * trusterItemFactor * trusterItemFactor; + } + } + + // social matrix + for (MatrixScalar term : socialMatrix) { + int trusterIndex = term.getRow(); + int trusteeIndex = term.getColumn(); + float score = term.getValue(); + DenseVector trusteeVector = trusteeUserDeltas.getRowVector(trusteeIndex); + DenseVector trusterVector = trusterUserFactors.getRowVector(trusterIndex); + float predict = scalar.dotProduct(trusteeVector, trusterVector).getValue(); + float error = LogisticUtility.getValue(predict) - score; + totalError += socialRegularization * error * error; + error = LogisticUtility.getGradient(predict) * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float trusterUserFactor = trusterUserFactors.getValue(trusterIndex, factorIndex); + float trusterUserDelta = trusteeUserDeltas.getValue(trusteeIndex, factorIndex); + trusterGradients.shiftValue(trusterIndex, factorIndex, socialRegularization * error * trusterUserDelta + userRegularization * trusterUserFactor); + trusteeGradients.shiftValue(trusteeIndex, factorIndex, socialRegularization * error * trusterUserFactor + userRegularization * trusterUserDelta); + totalError += userRegularization * trusterUserFactor * trusterUserFactor + userRegularization * trusterUserDelta * trusterUserDelta; + } + } + + trusteeUserDeltas.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + trusteeGradients.getValue(row, column) * -learnRatio); + }); + trusterUserFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + trusterGradients.getValue(row, column) * -learnRatio); + }); + trusterItemFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + itemGradients.getValue(row, column) * -learnRatio); + }); + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + } + } + + /** + * Build TrusteeMF model: We*Ve + * + * @throws ModelException if error occurs + */ + private void trainByTrustee(DefaultScalar scalar) { + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + // gradients of trusteeUserTrusterFactors, + // trusteeUserTrusteeFactors, trusteeItemFactors + DenseMatrix trusterGradients = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix trusteeGradients = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix itemGradients = DenseMatrix.valueOf(itemSize, factorSize); + + // rate matrix + for (MatrixScalar term : scoreMatrix) { + int userIndex = term.getRow(); + int itemIndex = term.getColumn(); + float score = term.getValue(); + float predict = predict(userIndex, itemIndex); + float error = LogisticUtility.getValue(predict) - normalize(score); + totalError += error * error; + error = LogisticUtility.getGradient(predict) * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float trusteeUserFactor = trusteeUserFactors.getValue(userIndex, factorIndex); + float trusteeItemFactor = trusteeItemFactors.getValue(itemIndex, factorIndex); + trusteeGradients.shiftValue(userIndex, factorIndex, error * trusteeItemFactor + userRegularization * trusteeUserFactor); + 
itemGradients.shiftValue(itemIndex, factorIndex, error * trusteeUserFactor + itemRegularization * trusteeItemFactor);
+                    totalError += userRegularization * trusteeUserFactor * trusteeUserFactor + itemRegularization * trusteeItemFactor * trusteeItemFactor;
+                }
+            }
+
+            // social matrix
+            for (MatrixScalar term : socialMatrix) {
+                int trusterIndex = term.getRow();
+                int trusteeIndex = term.getColumn();
+                float score = term.getValue();
+                DenseVector trusterVector = trusterUserDeltas.getRowVector(trusterIndex);
+                DenseVector trusteeVector = trusteeUserFactors.getRowVector(trusteeIndex);
+                float predict = scalar.dotProduct(trusterVector, trusteeVector).getValue();
+                float error = LogisticUtility.getValue(predict) - score;
+                totalError += socialRegularization * error * error;
+                error = LogisticUtility.getGradient(predict) * error;
+                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
+                    float trusteeUserFactor = trusteeUserFactors.getValue(trusteeIndex, factorIndex);
+                    float trusteeUserDelta = trusterUserDeltas.getValue(trusterIndex, factorIndex);
+                    trusteeGradients.shiftValue(trusteeIndex, factorIndex, socialRegularization * error * trusteeUserDelta + userRegularization * trusteeUserFactor);
+                    trusterGradients.shiftValue(trusterIndex, factorIndex, socialRegularization * error * trusteeUserFactor + userRegularization * trusteeUserDelta);
+                    totalError += userRegularization * trusteeUserFactor * trusteeUserFactor + userRegularization * trusteeUserDelta * trusteeUserDelta;
+                }
+            }
+
+            trusterUserDeltas.iterateElement(MathCalculator.PARALLEL, (element) -> {
+                int row = element.getRow();
+                int column = element.getColumn();
+                float value = element.getValue();
+                element.setValue(value + trusterGradients.getValue(row, column) * -learnRatio);
+            });
+            trusteeUserFactors.iterateElement(MathCalculator.PARALLEL, (element) -> {
+                int row = element.getRow();
+                int column = element.getColumn();
+                float value = element.getValue();
+                element.setValue(value + trusteeGradients.getValue(row, column) * -learnRatio);
+            });
+            trusteeItemFactors.iterateElement(MathCalculator.PARALLEL, (element) -> {
+                int row = element.getRow();
+                int column = element.getColumn();
+                float value = element.getValue();
+                element.setValue(value + itemGradients.getValue(row, column) * -learnRatio);
+            });
+
+            totalError *= 0.5F;
+            if (isConverged(epocheIndex) && isConverged) {
+                break;
+            }
+            isLearned(epocheIndex);
+            currentError = totalError;
+        }
+    }
+
+    @Override
+    protected void doPractice() {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        switch (mode) {
+        case "Tr":
+            trainByTruster(scalar);
+            break;
+        case "Te":
+            trainByTrustee(scalar);
+            break;
+        case "T":
+        default:
+            trainByTruster(scalar);
+            trainByTrustee(scalar);
+        }
+    }
+
+    /**
+     * Learning-rate decay schedule used by the authors of the paper
+     *
+     * @param iter iteration number
+     */
+    @Override
+    protected void isLearned(int iter) {
+        // TODO Refactor: make the decay points configurable
+        if (iter == 10) {
+            learnRatio *= 0.6;
+        } else if (iter == 30) {
+            learnRatio *= 0.333;
+        } else if (iter == 100) {
+            learnRatio *= 0.5;
+        }
+        currentError = totalError;
+    }
+
+    @Override
+    protected float predict(int userIndex, int itemIndex) {
+        DefaultScalar scalar = DefaultScalar.getInstance();
+        float value;
+        DenseVector userVector;
+        DenseVector itemVector;
+        switch (mode) {
+        case "Tr":
+            userVector = trusterUserFactors.getRowVector(userIndex);
+            itemVector = trusterItemFactors.getRowVector(itemIndex);
+            value = scalar.dotProduct(userVector, itemVector).getValue();
+            break;
+
case "Te": + userVector = trusteeUserFactors.getRowVector(userIndex); + itemVector = trusteeItemFactors.getRowVector(itemIndex); + value = scalar.dotProduct(userVector, itemVector).getValue(); + break; + case "T": + default: + DenseVector trusterUserVector = trusterUserFactors.getRowVector(userIndex); + DenseVector trusteeUserVector = trusteeUserFactors.getRowVector(userIndex); + DenseVector trusterItemVector = trusterItemFactors.getRowVector(itemIndex); + DenseVector trusteeItemVector = trusteeItemFactors.getRowVector(itemIndex); + value = 0F; + for (int index = 0; index < factorSize; index++) { + value += (trusterUserVector.getValue(index) + trusteeUserVector.getValue(index)) * (trusterItemVector.getValue(index) + trusteeItemVector.getValue(index)); + } + value /= 4F; + } + return value; + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float predict = predict(userIndex, itemIndex); + predict = denormalize(LogisticUtility.getValue(predict)); + instance.setQuantityMark(predict); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/rating/TrustSVDModel.java b/src/main/java/com/jstarcraft/rns/model/context/rating/TrustSVDModel.java new file mode 100644 index 0000000..e0ef91a --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/rating/TrustSVDModel.java @@ -0,0 +1,335 @@ +package com.jstarcraft.rns.model.context.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.SocialModel; +import com.jstarcraft.rns.model.exception.ModelException; + +/** + * + * TrustSVD推荐器 + * + *
+ * TrustSVD: Collaborative Filtering with Both the Explicit and Implicit Influence of User Trust and of Item Ratings
+ * Based on the LibRec team's implementation<br>
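+ *
+ * Prediction sketch (paper notation; what predict() assembles, assuming the cited
+ * paper):<br>
+ * r^_ui = mu + b_u + b_i + q_i^T * (p_u + |I_u|^(-1/2) * sum_(j in I_u) y_j<br>
+ *   + |T_u|^(-1/2) * sum_(v in T_u) w_v)<br>
+ * with the trust relation itself factorized as t_uv ~ p_u w_v^T in the social loop of
+ * doPractice().<br>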
+ * 
+ * + * @author Birdy + * + */ +public class TrustSVDModel extends SocialModel { + + private DenseMatrix itemExplicitFactors; + + /** + * impitemExplicitFactors denotes the implicit influence of items rated by user + * u in the past on the ratings of unknown items in the future. + */ + private DenseMatrix itemImplicitFactors; + + private DenseMatrix trusterFactors; + + /** + * the user-specific latent appender vector of users (trustees)trusted by user u + */ + private DenseMatrix trusteeFactors; + + /** + * weights of users(trustees) trusted by user u + */ + private DenseVector trusteeWeights; + + /** + * weights of users(trusters) who trust user u + */ + private DenseVector trusterWeights; + + /** + * weights of items rated by user u + */ + private DenseVector itemWeights; + + /** + * user biases and item biases + */ + private DenseVector userBiases, itemBiases; + + /** + * bias regularization + */ + private float regBias; + + /** + * initial the model + * + * @throws ModelException if error occurs + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // trusterFactors.init(1.0); + // itemExplicitFactors.init(1.0); + regBias = configuration.getFloat("recommender.bias.regularization", 0.01F); + + // initialize userBiases and itemBiases + // TODO 考虑重构 + userBiases = DenseVector.valueOf(userSize); + userBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemBiases = DenseVector.valueOf(itemSize); + itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + // initialize trusteeFactors and impitemExplicitFactors + trusterFactors = userFactors; + trusteeFactors = DenseMatrix.valueOf(userSize, factorSize); + trusteeFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + itemExplicitFactors = itemFactors; + itemImplicitFactors = DenseMatrix.valueOf(itemSize, factorSize); + itemImplicitFactors.iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(distribution.sample().floatValue()); + }); + + // initialize trusteeWeights, trusterWeights, impItemWeights + // TODO 考虑重构 + trusteeWeights = DenseVector.valueOf(userSize); + trusterWeights = DenseVector.valueOf(userSize); + itemWeights = DenseVector.valueOf(itemSize); + int socialCount; + for (int userIndex = 0; userIndex < userSize; userIndex++) { + socialCount = socialMatrix.getColumnScope(userIndex); + trusteeWeights.setValue(userIndex, (float) (socialCount > 0 ? 1F / Math.sqrt(socialCount) : 1F)); + socialCount = socialMatrix.getRowScope(userIndex); + trusterWeights.setValue(userIndex, (float) (socialCount > 0 ? 1F / Math.sqrt(socialCount) : 1F)); + } + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + int count = scoreMatrix.getColumnScope(itemIndex); + itemWeights.setValue(itemIndex, (float) (count > 0 ? 
1F / Math.sqrt(count) : 1F)); + } + } + + /** + * train model process + * + * @throws ModelException if error occurs + */ + @Override + protected void doPractice() { + DefaultScalar scalar = DefaultScalar.getInstance(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + // temp user Factors and trustee factors + DenseMatrix trusterDeltas = DenseMatrix.valueOf(userSize, factorSize); + DenseMatrix trusteeDeltas = DenseMatrix.valueOf(userSize, factorSize); + + for (MatrixScalar term : scoreMatrix) { + int trusterIndex = term.getRow(); // user userIdx + int itemExplicitIndex = term.getColumn(); // item itemIdx + // real rating on item itemIdx rated by user userIdx + float score = term.getValue(); + // To speed up, directly access the prediction instead of + // invoking "predictRating = predict(userIdx,itemIdx)" + float userBias = userBiases.getValue(trusterIndex); + float itemBias = itemBiases.getValue(itemExplicitIndex); + // TODO 考虑重构减少迭代 + DenseVector trusterVector = trusterFactors.getRowVector(trusterIndex); + DenseVector itemExplicitVector = itemExplicitFactors.getRowVector(itemExplicitIndex); + float predict = meanScore + userBias + itemBias + scalar.dotProduct(trusterVector, itemExplicitVector).getValue(); + + // get the implicit influence predict rating using items rated + // by user userIdx + SparseVector rateVector = scoreMatrix.getRowVector(trusterIndex); + if (rateVector.getElementSize() > 0) { + float sum = 0F; + for (VectorScalar rateTerm : rateVector) { + int itemImplicitIndex = rateTerm.getIndex(); + DenseVector itemImplicitVector = itemImplicitFactors.getRowVector(itemImplicitIndex); + sum += scalar.dotProduct(itemImplicitVector, itemExplicitVector).getValue(); + } + predict += sum / Math.sqrt(rateVector.getElementSize()); + } + + // the user-specific influence of users (trustees)trusted by + // user userIdx + SparseVector socialVector = socialMatrix.getRowVector(trusterIndex); + if (socialVector.getElementSize() > 0) { + float sum = 0F; + for (VectorScalar socialTerm : socialVector) { + int trusteeIndex = socialTerm.getIndex(); + sum += scalar.dotProduct(trusteeFactors.getRowVector(trusteeIndex), itemExplicitVector).getValue(); + } + predict += sum / Math.sqrt(socialVector.getElementSize()); + } + float error = predict - score; + totalError += error * error; + + float trusterDenominator = (float) Math.sqrt(rateVector.getElementSize()); + float trusteeDenominator = (float) Math.sqrt(socialVector.getElementSize()); + + float trusterWeight = 1F / trusterDenominator; + float itemExplicitWeight = itemWeights.getValue(itemExplicitIndex); + + // update factors + // stochastic gradient descent sgd + float sgd = error + regBias * trusterWeight * userBias; + userBiases.shiftValue(trusterIndex, -learnRatio * sgd); + sgd = error + regBias * itemExplicitWeight * itemBias; + itemBiases.shiftValue(itemExplicitIndex, -learnRatio * sgd); + totalError += regBias * trusterWeight * userBias * userBias + regBias * itemExplicitWeight * itemBias * itemBias; + + float[] itemSums = new float[factorSize]; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float sum = 0F; + for (VectorScalar rateTerm : rateVector) { + int itemImplicitIndex = rateTerm.getIndex(); + sum += itemImplicitFactors.getValue(itemImplicitIndex, factorIndex); + } + itemSums[factorIndex] = trusterDenominator > 0F ? 
sum / trusterDenominator : sum; + } + + float[] trusteesSums = new float[factorSize]; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float sum = 0F; + for (VectorScalar socialTerm : socialVector) { + int trusteeIndex = socialTerm.getIndex(); + sum += trusteeFactors.getValue(trusteeIndex, factorIndex); + } + trusteesSums[factorIndex] = trusteeDenominator > 0F ? sum / trusteeDenominator : sum; + } + + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = trusterFactors.getValue(trusterIndex, factorIndex); + float itemFactor = itemExplicitFactors.getValue(itemExplicitIndex, factorIndex); + float userDelta = error * itemFactor + userRegularization * trusterWeight * userFactor; + float itemDelta = error * (userFactor + itemSums[factorIndex] + trusteesSums[factorIndex]) + itemRegularization * itemExplicitWeight * itemFactor; + // update trusterDeltas + trusterDeltas.shiftValue(trusterIndex, factorIndex, userDelta); + // update itemExplicitFactors + itemExplicitFactors.shiftValue(itemExplicitIndex, factorIndex, -learnRatio * itemDelta); + totalError += userRegularization * trusterWeight * userFactor * userFactor + itemRegularization * itemExplicitWeight * itemFactor * itemFactor; + + // update itemImplicitFactors + for (VectorScalar rateTerm : rateVector) { + int itemImplicitIndex = rateTerm.getIndex(); + float itemImplicitFactor = itemImplicitFactors.getValue(itemImplicitIndex, factorIndex); + float itemImplicitWeight = itemWeights.getValue(itemImplicitIndex); + float itemImplicitDelta = error * itemFactor / trusterDenominator + itemRegularization * itemImplicitWeight * itemImplicitFactor; + itemImplicitFactors.shiftValue(itemImplicitIndex, factorIndex, -learnRatio * itemImplicitDelta); + totalError += itemRegularization * itemImplicitWeight * itemImplicitFactor * itemImplicitFactor; + } + + // update trusteeDeltas + for (VectorScalar socialTerm : socialVector) { + int trusteeIndex = socialTerm.getIndex(); + float trusteeFactor = trusteeFactors.getValue(trusteeIndex, factorIndex); + float trusteeWeight = trusteeWeights.getValue(trusteeIndex); + float trusteeDelta = error * itemFactor / trusteeDenominator + userRegularization * trusteeWeight * trusteeFactor; + trusteeDeltas.shiftValue(trusteeIndex, factorIndex, trusteeDelta); + totalError += userRegularization * trusteeWeight * trusteeFactor * trusteeFactor; + } + } + } + + for (MatrixScalar socialTerm : socialMatrix) { + int trusterIndex = socialTerm.getRow(); + int trusteeIndex = socialTerm.getColumn(); + float score = socialTerm.getValue(); + DenseVector trusterVector = trusterFactors.getRowVector(trusterIndex); + DenseVector trusteeVector = trusteeFactors.getRowVector(trusteeIndex); + float predtict = scalar.dotProduct(trusterVector, trusteeVector).getValue(); + float error = predtict - score; + totalError += socialRegularization * error * error; + error = socialRegularization * error; + + float trusterWeight = trusterWeights.getValue(trusterIndex); + // update trusterDeltas,trusteeDeltas + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float trusterFactor = trusterFactors.getValue(trusterIndex, factorIndex); + float trusteeFactor = trusteeFactors.getValue(trusteeIndex, factorIndex); + trusterDeltas.shiftValue(trusterIndex, factorIndex, error * trusteeFactor + socialRegularization * trusterWeight * trusterFactor); + trusteeDeltas.shiftValue(trusteeIndex, factorIndex, error * trusterFactor); + totalError += socialRegularization * trusterWeight * trusterFactor 
* trusterFactor; + } + } + + trusterFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + trusterDeltas.getValue(row, column) * -learnRatio); + }); + trusteeFactors.iterateElement(MathCalculator.PARALLEL, (element) -> { + int row = element.getRow(); + int column = element.getColumn(); + float value = element.getValue(); + element.setValue(value + trusteeDeltas.getValue(row, column) * -learnRatio); + }); + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } // end of training + } + + /** + * predict a specific rating for user userIdx on item itemIdx. + * + * @param userIndex user index + * @param itemIndex item index + * @return predictive rating for user userIdx on item itemIdx + * @throws ModelException if error occurs + */ + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + DefaultScalar scalar = DefaultScalar.getInstance(); + DenseVector trusterVector = trusterFactors.getRowVector(userIndex); + DenseVector itemExplicitVector = itemExplicitFactors.getRowVector(itemIndex); + float value = meanScore + userBiases.getValue(userIndex) + itemBiases.getValue(itemIndex) + scalar.dotProduct(trusterVector, itemExplicitVector).getValue(); + + // the implicit influence of items rated by user in the past on the + // ratings of unknown items in the future. + SparseVector rateVector = scoreMatrix.getRowVector(userIndex); + if (rateVector.getElementSize() > 0) { + float sum = 0F; + for (VectorScalar rateTerm : rateVector) { + itemIndex = rateTerm.getIndex(); + // TODO 考虑重构减少迭代 + DenseVector itemImplicitVector = itemImplicitFactors.getRowVector(itemIndex); + sum += scalar.dotProduct(itemImplicitVector, itemExplicitVector).getValue(); + } + value += sum / Math.sqrt(rateVector.getElementSize()); + } + + // the user-specific influence of users (trustees)trusted by user u + SparseVector socialVector = socialMatrix.getRowVector(userIndex); + if (socialVector.getElementSize() > 0) { + float sum = 0F; + for (VectorScalar socialTerm : socialVector) { + userIndex = socialTerm.getIndex(); + DenseVector trusteeVector = trusteeFactors.getRowVector(userIndex); + sum += scalar.dotProduct(trusteeVector, itemExplicitVector).getValue(); + } + value += sum / Math.sqrt(socialVector.getElementSize()); + } + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/context/social.txt b/src/main/java/com/jstarcraft/rns/model/context/social.txt new file mode 100644 index 0000000..15aa324 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/context/social.txt @@ -0,0 +1,29 @@ +人际关系是一种十分复杂的现象,我们可以把日常生活中所存在的人际关系划分为以下四个维度. +四个维度的不同结合,就出现了许多不同的人际关系. +1.基缘维 +所谓基缘,是指构成人际关系的最基本的因素,包括血缘、地缘、业缘、趣缘等. +血缘构成的人际关系,泛指因血缘和姻缘联系而交往形成的人际关系,如父子、母女、叔侄、夫妻、姨表、婆媳等等. +地缘关系是由于人们共同生活,活动在相同的空间而交往产生的关系,如“老乡”、“邻居”、“同单位”等等,它带有强烈的地方色彩. +业缘人际关系,在人际关系整体中占的比例大,对于社会的影响也大,如“师徒”、“同学”、“买卖双方”等等.所有以职业、行业、专业、事业为媒介建立的关系都属此类.良好的业缘人际关系,是促进组织顺利活动的动力. +趣缘人际关系,是指人们在交往过程中,因趣味相投而建立的朋友、球友、棋友等多种人际关系.趣缘人际关系是以人们之间的感情和趣味为介质而建立的,它对于协调人与人的关系有重要的意义. +2.间距维 +这是指在交往过程中,交往双方相处的距离的远近对双方关系的影响. +这里所谈的间距,有两方面含义. +一是指自然间距,就是人们在客观环境中所处的距离.这种距离的远近,对于人们在行为上的相互接触,起着影响作用. +另一方面是指心理距离,就是指人们的心理活动是不是在同一水平线上,有无差距,这种距离的远近,对于人们的心理上相互沟通起着影响作用. +交往的双方如若建立良好的关系,必须在这两种距离都很适宜的时候,才能得以实现.否则,天天处在同一环境中,不但不能和睦相处反而相互磨擦. 
+我们知道,在同一群体内生活的一些人,虽然所处的客观环境相同,但由于心理距离相隔甚远,所以,难以做到心理相容,因而矛盾不断,由此削减了群体的力量,给活动带来诸多不利. +3.交频维 +这是指交往双方频率的多少对人际关系的影响. +在群体中,人与人之间能不能结成良好的关系,仅具备基缘、间距因素还不够,还缺乏把基缘相联、把间距相融的中间递质,而交频则充当了这一递质. +在群体中,有些人之间,虽然有某方面的基缘,如有一定的亲戚关系,但由于二者的交往频率非常少,难以相互了解,也只能处在不冷不热的水平. +中国有句俗语叫“走亲戚”,就包括有这样一层含义,亲戚是走出来的,两家越走越亲,因为有了走,才能相互了解,相互了解了,关系也就更亲密了. +间距之间要做得相融,也需要交频来起作用,交往频率多了,间距也就小了.否则,虽相处同一空间,心理间距也永远不能相容,甚至会反目为仇.我们在日常生活中所见的“隔墙邻居结为仇”等,就属这一类. +所以说,交往频率不仅充当了“递质”,而且在某种意义上,它还对交往程度的高低起决定作用,影响人际关系的水平. +4.信传维 +这是指交往过程中信息传递对形成良好人际关系的影响. +信息传递从表面看来,它似乎是交往者的携带物,不是独立因素.但事实上,它却是人际关系形成中的一个重要因素. +人们在建立关系时,如基缘、间距、交频都较合适,但就是缺乏信传,那么,在这种情况下,人际关系是建立不起来的. +因为没有传递,也许人们之间各方面较接近,但却无法相互了解,不能相互了解,是没办法建立良好关系的. +平时我们见到有些人闹矛盾,其实他们之间并没有根本利益的冲突,只不过是话没说透,两个人联不到一起造成的.而这个“话没说透”,就是“信传”没有及时起到作用,因此引起风波.可见,没有“信传”因素的参与,是不可能建立良好的人际关系的,而没有良好的人际关系作基础,群体也就无凝聚力可言. +以上所述四种基本维度,是人际关系形成的根本,四种维度之间,相互联系、相互作用.各种维度的不同结合,就会产生不同的人际关系. \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/dl4j.txt b/src/main/java/com/jstarcraft/rns/model/dl4j/dl4j.txt new file mode 100644 index 0000000..f3acada --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/dl4j.txt @@ -0,0 +1,10 @@ +https://github.com/IceS2388/paper-recommender + +Next RecSys Library: +https://github.com/wubinzzu/NeuRec + +A unified, comprehensive and efficient recommendation library: +https://github.com/RUCAIBox/RecBole + +Easy-to-use,Modular and Extendible package of deep-learning based CTR models: +https://github.com/shenweichen/DeepCTR diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ncf.txt b/src/main/java/com/jstarcraft/rns/model/dl4j/ncf.txt new file mode 100644 index 0000000..5a1fee6 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ncf.txt @@ -0,0 +1,2 @@ +Neural-Collaborative-Filtering: +https://github.com/enningxie/Neural-Collaborative-Filtering diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ngcf.txt b/src/main/java/com/jstarcraft/rns/model/dl4j/ngcf.txt new file mode 100644 index 0000000..109671e --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ngcf.txt @@ -0,0 +1 @@ +https://github.com/huangtinglin/NGCF-PyTorch diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEConfiguration.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEConfiguration.java new file mode 100644 index 0000000..37607f3 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEConfiguration.java @@ -0,0 +1,148 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Collection; +import java.util.Map; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.memory.MemoryReport; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * + * CDAE配置 + * + *
+ * Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
+ * Based on the LibRec team's implementation<br>
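+ *
+ * Model sketch (paper notation, assuming the cited paper): CDAE is a one-hidden-layer
+ * denoising autoencoder with a per-user embedding V_u added to the hidden
+ * pre-activation,<br>
+ * z_u = h(W^T y~_u + V_u + b), y^_ui = f(W'_i z_u + b'_i)<br>
+ * where y~_u is the corrupted rating vector of user u; numberOfUsers sizes the V
+ * matrix initialized by CDAEParameter.<br>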
+ * 
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class CDAEConfiguration extends FeedForwardLayer { + + private CDAEParameter cdaeParameter; + + CDAEConfiguration() { + // We need a no-arg constructor so we can deserialize the configuration + // from JSON or YAML format + // Without this, you will likely get an exception like the following: + // com.fasterxml.jackson.databind.JsonMappingException: No suitable + // constructor found for type [simple type, class + // org.deeplearning4j.examples.misc.customlayers.layer.CustomLayer]: can + // not instantiate from JSON object (missing default constructor or + // creator, or perhaps need to add/enable type information?) + } + + private CDAEConfiguration(Builder builder) { + super(builder); + this.cdaeParameter = new CDAEParameter(builder.numberOfUsers); + } + + @Override + public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) { + // The instantiate method is how we go from the configuration class + // (i.e., this class) to the implementation class + // (i.e., a CustomLayerImpl instance) + // For the most part, it's the same for each type of layer + + CDAELayer myCustomLayer = new CDAELayer(conf); + myCustomLayer.setListeners(iterationListeners); // Set the iteration + // listeners, if any + myCustomLayer.setIndex(layerIndex); // Integer index of the layer + + // Parameter view array: In Deeplearning4j, the network parameters for + // the entire network (all layers) are + // allocated in one big array. The relevant section of this parameter + // vector is extracted out for each layer, + // (i.e., it's a "view" array in that it's a subset of a larger array) + // This is a row vector, with length equal to the number of parameters + // in the layer + myCustomLayer.setParamsViewArray(layerParamsView); + + // Initialize the layer parameters. For example, + // Note that the entries in paramTable (2 entries here: a weight array + // of shape [nIn,nOut] and biases of shape [1,nOut] + // are in turn a view of the 'layerParamsView' array. 
+ Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + myCustomLayer.setParamTable(paramTable); + myCustomLayer.setConf(conf); + return myCustomLayer; + } + + @Override + public ParamInitializer initializer() { + // This method returns the parameter initializer for this type of layer + // In this case, we can use the DefaultParamInitializer, which is the + // same one used for DenseLayer + // For more complex layers, you may need to implement a custom parameter + // initializer + // See the various parameter initializers here: + // https://github.com/deeplearning4j/deeplearning4j/tree/master/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/params + return cdaeParameter; + } + + @Override + public double getL1ByParam(String paramName) { + switch (paramName) { + case CDAEParameter.WEIGHT_KEY: + return l1; + case CDAEParameter.BIAS_KEY: + return l1Bias; + case CDAEParameter.USER_KEY: + return l1; + default: + throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); + } + } + + @Override + public double getL2ByParam(String paramName) { + switch (paramName) { + case CDAEParameter.WEIGHT_KEY: + return l2; + case CDAEParameter.BIAS_KEY: + return l2Bias; + case CDAEParameter.USER_KEY: + return l2; + default: + throw new IllegalArgumentException("Unknown parameter name: \"" + paramName + "\""); + } + } + + // Here's an implementation of a builder pattern, to allow us to easily + // configure the layer + // Note that we are inheriting all of the FeedForwardLayer.Builder options: + // things like n + public static class Builder extends FeedForwardLayer.Builder { + private int numberOfUsers; + + @Override + @SuppressWarnings("unchecked") // To stop warnings about unchecked cast. + // Not required. + public CDAEConfiguration build() { + return new CDAEConfiguration(this); + } + + public Builder setNumUsers(int numUsers) { + this.numberOfUsers = numUsers; + return this; + } + } + + @Override + public LayerMemoryReport getMemoryReport(InputType inputType) { + return new LayerMemoryReport.Builder(layerName, CDAEConfiguration.class, inputType, inputType).standardMemory(0, 0) // No params + .workingMemory(0, 0, 0, 0).cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching + .build(); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAELayer.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAELayer.java new file mode 100644 index 0000000..7d3c711 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAELayer.java @@ -0,0 +1,117 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Arrays; + +import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.BaseLayer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +/** + * + * CDAE层 + * + *
+ * Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
+ * Based on the implementation by the LibRec team
+ * 
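+ * <p>
+ * A sketch of the computation, in informal notation mirroring {@code preOutput} and
+ * {@code backpropGradient} below. Note that {@code ret.addi(U)} adds the entire
+ * user-weight matrix, so each batch is assumed to contain all users in index order:
+ * <pre>
+ * // forward:  Z = X * W + U + b     X: [users, nIn], W: [nIn, nOut], U: [users, nOut]
+ * // backward: dW = X^T * delta,  dU = delta,  db = sum of delta over rows
+ * </pre>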
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class CDAELayer extends BaseLayer { + + public CDAELayer(NeuralNetConfiguration conf) { + super(conf); + } + + @Override + public INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(false); + applyDropOutIfNecessary(training, workspaceMgr); + INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr); + INDArray U = getParamWithNoise(CDAEParameter.USER_KEY, training, workspaceMgr); + INDArray b = getParamWithNoise(DefaultParamInitializer.BIAS_KEY, training, workspaceMgr); + + // Input validation: + if (input.rank() != 2 || input.columns() != W.rows()) { + if (input.rank() != 2) { + throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank " + input.rank() + " array with shape " + Arrays.toString(input.shape()) + ". Missing preprocessor or wrong input type? " + layerId()); + } + throw new DL4JInvalidInputException("Input size (" + input.columns() + " columns; shape = " + Arrays.toString(input.shape()) + ") is invalid: does not match layer input size (layer # inputs = " + W.size(0) + ") " + layerId()); + } + + INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.size(0), W.size(1)); + input.mmuli(W, ret); + ret.addi(U); + if (hasBias()) { + ret.addiRowVector(b); + } + + if (maskArray != null) { + applyMask(ret); + } + + return ret; + } + + @Override + public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(true); + // If this layer is layer L, then epsilon is (w^(L+1)*(d^(L+1))^T) (or + // equivalent) + INDArray z = preOutput(true, workspaceMgr); // Note: using preOutput(INDArray) can't be used as this does a setInput(input) + // and resets the 'appliedDropout' flag + // INDArray activationDerivative = + // Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(conf().getLayer().getActivationFunction(), + // z).derivative()); + // INDArray activationDerivative = + // conf().getLayer().getActivationFn().getGradient(z); + // INDArray delta = epsilon.muli(activationDerivative); + INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); // TODO handle activation function params + + if (maskArray != null) { + applyMask(delta); + } + + Gradient ret = new DefaultGradient(); + + INDArray weightGrad = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); // f order + Nd4j.gemm(input, delta, weightGrad, true, false, 1.0, 0.0); + ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad); + + INDArray userWeightGrad = gradientViews.get(CDAEParameter.USER_KEY); + userWeightGrad.assign(delta); + ret.gradientForVariable().put(CDAEParameter.USER_KEY, userWeightGrad); + + if (hasBias()) { + INDArray biasGrad = gradientViews.get(DefaultParamInitializer.BIAS_KEY); + delta.sum(biasGrad, 0); // biasGrad is initialized/zeroed first + ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad); + } + + INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); + INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[] { W.size(0), delta.size(0) }, 'f'); + epsilonNext = W.mmuli(delta.transpose(), epsilonNext).transpose(); // W.mmul(delta.transpose()).transpose(); + + weightNoiseParams.clear(); + + epsilonNext = backpropDropOutIfPresent(epsilonNext); + return new Pair<>(ret, epsilonNext); + } + + @Override + public boolean 
isPretrainLayer() { + return false; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEModel.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEModel.java new file mode 100644 index 0000000..053d27a --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEModel.java @@ -0,0 +1,82 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Nesterovs; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.NeuralNetworkModel; + +/** + * + * CDAE推荐器 + * + *
+ * Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
+ * Based on the implementation by the LibRec team
+ * 
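+ * <p>
+ * Input construction sketch (mirrors {@code prepare} below; the threshold is read
+ * from the configuration key shown, the rest of the notation is informal):
+ * <pre>
+ * // binarie = recommender.binarize.threshold
+ * // inputData[user][item] = (score > binarie) ? 1 : 0
+ * </pre>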
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class CDAEModel extends NeuralNetworkModel { + + /** + * the threshold to binarize the rating + */ + private double binarie; + + @Override + protected int getInputDimension() { + return itemSize; + } + + @Override + protected MultiLayerConfiguration getNetworkConfiguration() { + NeuralNetConfiguration.ListBuilder factory = new NeuralNetConfiguration.Builder().seed(6) + // .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) + // .gradientNormalizationThreshold(1.0) + .updater(new Nesterovs(learnRatio, momentum)).weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(weightRegularization).list(); + factory.layer(0, new CDAEConfiguration.Builder().nIn(inputDimension).nOut(hiddenDimension).activation(Activation.fromString(hiddenActivation)).setNumUsers(userSize).build()); + factory.layer(1, new OutputLayer.Builder().nIn(hiddenDimension).nOut(inputDimension).lossFunction(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.fromString(outputActivation)).build()); + factory.pretrain(false).backprop(true); + MultiLayerConfiguration configuration = factory.build(); + return configuration; + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + binarie = configuration.getFloat("recommender.binarize.threshold"); + // transform the sparse matrix to INDArray + // the sparse training matrix has been binarized + + int[] matrixShape = new int[] { userSize, itemSize }; + inputData = Nd4j.zeros(matrixShape); + for (MatrixScalar term : scoreMatrix) { + if (term.getValue() > binarie) { + inputData.putScalar(term.getRow(), term.getColumn(), 1D); + } + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(outputData.getFloat(userIndex, itemIndex)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEParameter.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEParameter.java new file mode 100644 index 0000000..94d2447 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEParameter.java @@ -0,0 +1,84 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Map; + +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.distribution.Distributions; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.rng.distribution.Distribution; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +/** + * + * CDAE参数 + * + *
+ * Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
+ * Based on the implementation by the LibRec team
+ * 
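+ * <p>
+ * Layout of the flattened parameter view, as assumed by {@code init} and
+ * {@code getGradientsFromFlattened} below:
+ * <pre>
+ * // [0, nIn*nOut)                                -> "W" weights [nIn, nOut]
+ * // [nIn*nOut, nIn*nOut + nOut)                  -> "b" bias [1, nOut]
+ * // [nIn*nOut + nOut, ... + numberOfUsers*nOut)  -> "u" user weights [numberOfUsers, nOut]
+ * </pre>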
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +class CDAEParameter extends DefaultParamInitializer { + + public static final String USER_KEY = "u"; + + public int numberOfUsers; + + public CDAEParameter(int numberOfUsers) { + this.numberOfUsers = numberOfUsers; + } + + @Override + public long numParams(NeuralNetConfiguration conf) { + FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer(); + return super.numParams(conf) + numberOfUsers * layerConf.getNOut(); // add + // another + // user + // weight + // matrix + } + + private INDArray createUserWeightMatrix(NeuralNetConfiguration conf, INDArray weightParamView, boolean initializeParameters) { + FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer(); + if (initializeParameters) { + Distribution dist = Distributions.createDistribution(layerConf.getDist()); + return createWeightMatrix(numberOfUsers, layerConf.getNOut(), layerConf.getWeightInit(), dist, weightParamView, true); + } else { + return createWeightMatrix(numberOfUsers, layerConf.getNOut(), null, null, weightParamView, false); + } + } + + @Override + public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + Map params = super.init(conf, paramsView, initializeParams); + FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer(); + long nIn = layerConf.getNIn(); + long nOut = layerConf.getNOut(); + long nWeightParams = nIn * nOut; + long nUserWeightParams = numberOfUsers * nOut; + INDArray userWeightView = paramsView.get(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.interval(nWeightParams + nOut, nWeightParams + nOut + nUserWeightParams) }); + params.put(USER_KEY, this.createUserWeightMatrix(conf, userWeightView, initializeParams)); + conf.addVariable(USER_KEY); + return params; + } + + @Override + public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + Map out = super.getGradientsFromFlattened(conf, gradientView); + FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer(); + long nIn = layerConf.getNIn(); + long nOut = layerConf.getNOut(); + long nWeightParams = nIn * nOut; + long nUserWeightParams = numberOfUsers * nOut; + INDArray userWeightGradientView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nWeightParams + nOut, nWeightParams + nOut + nUserWeightParams)).reshape('f', numberOfUsers, nOut); + out.put(USER_KEY, userWeightGradientView); + return out; + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMInputConfiguration.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMInputConfiguration.java new file mode 100644 index 0000000..998a725 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMInputConfiguration.java @@ -0,0 +1,85 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Collection; +import java.util.Map; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.memory.MemoryReport; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * + * DeepFM输入配置 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
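+ * <p>
+ * Usage sketch, taken from how {@code DeepFMModel} wires this layer into its
+ * computation graph:
+ * <pre>
+ * graphBuilder.addLayer("FMPlus",
+ *         new DeepFMInputConfiguration.Builder(dimensionSizes)
+ *                 .nOut(1).activation(Activation.IDENTITY).build(),
+ *         "FMInputs");
+ * </pre>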
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMInputConfiguration extends FeedForwardLayer { + + private int[] dimensionSizes; + + private DeepFMParameter deepFMParameter; + + @Override + public ParamInitializer initializer() { + return new DeepFMParameter(dimensionSizes); + } + + public DeepFMInputConfiguration(int[] dimensionSizes) { + this.deepFMParameter = new DeepFMParameter(dimensionSizes); + this.dimensionSizes = dimensionSizes; + } + + private DeepFMInputConfiguration(Builder builder) { + super(builder); + this.deepFMParameter = new DeepFMParameter(builder.dimensionSizes); + this.dimensionSizes = builder.dimensionSizes; + } + + @Override + public Layer instantiate(NeuralNetConfiguration configuration, Collection monitors, int layerIndex, INDArray parameters, boolean initialize) { + DeepFMInputLayer layer = new DeepFMInputLayer(configuration, dimensionSizes); + layer.setListeners(monitors); + layer.setIndex(layerIndex); + layer.setParamsViewArray(parameters); + Map table = initializer().init(configuration, parameters, initialize); + layer.setParamTable(table); + layer.setConf(configuration); + return layer; + } + + public static class Builder extends FeedForwardLayer.Builder { + + private int[] dimensionSizes; + + public Builder(int[] dimensionSizes) { + this.dimensionSizes = dimensionSizes; + } + + @Override + public DeepFMInputConfiguration build() { + return new DeepFMInputConfiguration(this); + } + + } + + @Override + public LayerMemoryReport getMemoryReport(InputType inputType) { + LayerMemoryReport.Builder builder = new LayerMemoryReport.Builder(layerName, DeepFMInputConfiguration.class, inputType, inputType); + builder.standardMemory(0, 0).workingMemory(0, 0, 0, 0).cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS); + return builder.build(); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMInputLayer.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMInputLayer.java new file mode 100644 index 0000000..84c1daa --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMInputLayer.java @@ -0,0 +1,125 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.BaseLayer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; + +/** + * + * DeepFM输入层 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
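+ * <p>
+ * Indexing sketch (mirrors {@code preOutput} below): each input column holds a feature
+ * index within its field, and weight rows are offset by the sizes of the preceding
+ * fields. With a made-up {@code dimensionSizes = {3, 5}}:
+ * <pre>
+ * // field 0 uses weight rows [0, 3), field 1 uses weight rows [3, 8)
+ * // preOutput[row][column] = W[input[row][0]][column] + W[3 + input[row][1]][column] + b[column]
+ * </pre>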
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMInputLayer extends BaseLayer { + + private int[] dimensionSizes; + + public DeepFMInputLayer(NeuralNetConfiguration configuration, int[] dimensionSizes) { + super(configuration); + this.dimensionSizes = dimensionSizes; + } + + @Override + public INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(false); + applyDropOutIfNecessary(training, workspaceMgr); + INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr); + INDArray b = getParamWithNoise(DefaultParamInitializer.BIAS_KEY, training, workspaceMgr); + + INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.size(0), W.size(1)); + ret.assign(0F); + for (int row = 0; row < input.rows(); row++) { + for (int column = 0; column < W.columns(); column++) { + float value = 0F; + int cursor = 0; + for (int index = 0; index < input.columns(); index++) { + value += W.getFloat(cursor + input.getInt(row, index), column); + cursor += dimensionSizes[index]; + } + ret.put(row, column, value); + } + } + + if (hasBias()) { + ret.addiRowVector(b); + } + + if (maskArray != null) { + applyMask(ret); + } + + return ret; + } + + @Override + public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(true); + // If this layer is layer L, then epsilon is (w^(L+1)*(d^(L+1))^T) (or + // equivalent) + INDArray z = preOutput(true, workspaceMgr); // Note: using preOutput(INDArray) can't be used as this does a setInput(input) + // and resets the 'appliedDropout' flag + // INDArray activationDerivative = + // Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(conf().getLayer().getActivationFunction(), + // z).derivative()); + // INDArray activationDerivative = + // conf().getLayer().getActivationFn().getGradient(z); + // INDArray delta = epsilon.muli(activationDerivative); + INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); // TODO handle activation function params + + if (maskArray != null) { + applyMask(delta); + } + + Gradient ret = new DefaultGradient(); + + INDArray weightGrad = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); // f order + weightGrad.assign(0F); + for (int index = 0; index < input.rows(); index++) { + for (int column = 0; column < delta.columns(); column++) { + int cursor = 0; + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + int point = cursor + input.getInt(index, dimension); + float value = weightGrad.getFloat(point, column); + value += delta.getFloat(index, column); + weightGrad.put(point, column, value); + cursor += dimensionSizes[dimension]; + } + } + } + ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad); + + if (hasBias()) { + INDArray biasGrad = gradientViews.get(DefaultParamInitializer.BIAS_KEY); + delta.sum(biasGrad, 0); // biasGrad is initialized/zeroed first + ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad); + } + + INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); + INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[] { W.size(0), delta.size(0) }, 'f'); + epsilonNext = W.mmuli(delta.transpose(), epsilonNext).transpose(); // W.mmul(delta.transpose()).transpose(); + + weightNoiseParams.clear(); + + epsilonNext = backpropDropOutIfPresent(epsilonNext); + return new Pair<>(ret, epsilonNext); + } + + @Override 
+ public boolean isPretrainLayer() { + return false; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMModel.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMModel.java new file mode 100644 index 0000000..009a439 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMModel.java @@ -0,0 +1,333 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Map.Entry; + +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.learning.config.Sgd; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.processor.DataSorter; +import com.jstarcraft.ai.data.processor.DataSplitter; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.data.processor.AllFeatureDataSorter; +import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter; +import com.jstarcraft.rns.model.EpocheModel; + +/** + * + * DeepFM推荐器 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
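+ * <p>
+ * Topology sketch of the graph built by {@code getComputationGraphConfiguration} below
+ * (vertex names as used in the code):
+ * <pre>
+ * SparseField{i} -> Embed{i}                          // one embedding per discrete field
+ * FMInputs -> FMPlus                                  // first-order FM term (own one-hot input)
+ * Embed{i}, Embed{j} -> FMProduct{i}:{j}              // pairwise second-order FM terms
+ * FMProduct*, FMPlus -> FMOutput                      // FM component (sum vertex)
+ * Embed* -> NetInput -> NetHidden0..4 -> NetOutput    // deep component
+ * FMOutput, NetOutput -> DeepOutput                   // sigmoid output with XENT loss
+ * </pre>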
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMModel extends EpocheModel { + + /** + * the learning rate of the optimization algorithm + */ + protected float learnRatio; + + /** + * the momentum of the optimization algorithm + */ + protected float momentum; + + /** + * the regularization coefficient of the weights in the neural network + */ + protected float weightRegularization; + + /** + * 所有维度的特征总数 + */ + private int numberOfFeatures; + + /** + * the data structure that stores the training data + */ + protected INDArray[] inputData; + + /** + * the data structure that stores the predicted data + */ + protected INDArray outputData; + + /** + * 计算图 + */ + protected ComputationGraph graph; + + protected int[] dimensionSizes; + + protected DataModule marker; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + learnRatio = configuration.getFloat("recommender.iterator.learnrate"); + momentum = configuration.getFloat("recommender.iterator.momentum"); + weightRegularization = configuration.getFloat("recommender.weight.regularization"); + this.marker = model; + + // TODO 此处需要重构,外部索引与内部索引的映射转换 + dimensionSizes = new int[model.getQualityOrder()]; + for (int orderIndex = 0, orderSize = model.getQualityOrder(); orderIndex < orderSize; orderIndex++) { + Entry> term = model.getOuterKeyValue(orderIndex); + dimensionSizes[model.getQualityInner(term.getValue().getKey())] = space.getQualityAttribute(term.getValue().getKey()).getSize(); + } + } + + /** + * 获取计算图配置 + * + * @param dimensionSizes + * @return + */ + protected ComputationGraphConfiguration getComputationGraphConfiguration(int[] dimensionSizes) { + NeuralNetConfiguration.Builder netBuilder = new NeuralNetConfiguration.Builder(); + // 设置随机种子 + netBuilder.seed(6); + netBuilder.weightInit(WeightInit.XAVIER_UNIFORM); + netBuilder.updater(new Sgd(learnRatio)).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); + netBuilder.l1(weightRegularization); + + GraphBuilder graphBuilder = netBuilder.graphBuilder(); + + // 构建离散域(SparseField)节点 + String[] inputVertexNames = new String[dimensionSizes.length]; + int[] inputVertexSizes = new int[dimensionSizes.length]; + for (int fieldIndex = 0; fieldIndex < dimensionSizes.length; fieldIndex++) { + inputVertexNames[fieldIndex] = "SparseField" + fieldIndex; + // 每个离散特征的输入数 + inputVertexSizes[fieldIndex] = dimensionSizes[fieldIndex]; + } + graphBuilder.addInputs(inputVertexNames); + + // 构建Embed节点 + // TODO 应该调整为配置项. + int numberOfFactors = 10; + // TODO Embed只支持输入的column为1. + String[] embedVertexNames = new String[dimensionSizes.length]; + for (int fieldIndex = 0; fieldIndex < dimensionSizes.length; fieldIndex++) { + embedVertexNames[fieldIndex] = "Embed" + fieldIndex; + graphBuilder.addLayer(embedVertexNames[fieldIndex], new EmbeddingLayer.Builder().nIn(inputVertexSizes[fieldIndex]).nOut(numberOfFactors).activation(Activation.IDENTITY).build(), inputVertexNames[fieldIndex]); + } + + // 构建因子分解机部分 + // 构建FM Plus节点(实际就是FM的输入) + numberOfFeatures = 0; + for (int fieldIndex = 0; fieldIndex < dimensionSizes.length; fieldIndex++) { + numberOfFeatures += inputVertexSizes[fieldIndex]; + } + // TODO 注意,由于EmbedLayer不支持与其它Layer共享输入,所以FM Plus节点构建自己的One Hot输入. 
+ graphBuilder.addInputs("FMInputs"); + graphBuilder.addLayer("FMPlus", new DeepFMInputConfiguration.Builder(dimensionSizes).nOut(1).activation(Activation.IDENTITY).build(), "FMInputs"); + + // 构建FM Product节点 + // 注意:节点数量是(n*(n-1)/2)),n为Embed节点数量 + String[] productVertexNames = new String[dimensionSizes.length * (dimensionSizes.length - 1) / 2]; + int productIndex = 0; + for (int outterFieldIndex = 0; outterFieldIndex < dimensionSizes.length; outterFieldIndex++) { + for (int innerFieldIndex = outterFieldIndex + 1; innerFieldIndex < dimensionSizes.length; innerFieldIndex++) { + productVertexNames[productIndex] = "FMProduct" + outterFieldIndex + ":" + innerFieldIndex; + String left = embedVertexNames[outterFieldIndex]; + String right = embedVertexNames[innerFieldIndex]; + graphBuilder.addVertex(productVertexNames[productIndex], new DeepFMProductConfiguration(), left, right); + productIndex++; + } + } + + // 构建FM Sum节点(实际就是FM的输出) + String[] names = new String[productVertexNames.length + 1]; + System.arraycopy(productVertexNames, 0, names, 0, productVertexNames.length); + names[names.length - 1] = "FMPlus"; + graphBuilder.addVertex("FMOutput", new DeepFMSumConfiguration(), names); + + // 构建多层网络部分 + // 构建Net Input节点 + // TODO 调整为支持输入(连续域)Dense Field. + // TODO 应该调整为配置项. + int numberOfHiddens = 100; + graphBuilder.addLayer("NetInput", new DenseLayer.Builder().nIn(dimensionSizes.length * numberOfFactors).nOut(numberOfHiddens).activation(Activation.LEAKYRELU).build(), embedVertexNames); + + // TODO 应该调整为配置项. + int numberOfLayers = 5; + String currentLayer = "NetInput"; + for (int layerIndex = 0; layerIndex < numberOfLayers; layerIndex++) { + graphBuilder.addLayer("NetHidden" + layerIndex, new DenseLayer.Builder().nIn(numberOfHiddens).nOut(numberOfHiddens).activation(Activation.LEAKYRELU).build(), currentLayer); + currentLayer = "NetHidden" + layerIndex; + } + + // 构建Net Output节点 + graphBuilder.addVertex("NetOutput", new DeepFMSumConfiguration(), currentLayer); + + // 构建Deep Output节点 + graphBuilder.addLayer("DeepOutput", new DeepFMOutputConfiguration.Builder(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).nIn(2).nOut(1).build(), "FMOutput", "NetOutput"); + + graphBuilder.setOutputs("DeepOutput"); + ComputationGraphConfiguration configuration = graphBuilder.build(); + return configuration; + } + + @Override + protected void doPractice() { + DataSplitter splitter = new QualityFeatureDataSplitter(userDimension); + DataModule[] models = splitter.split(marker, userSize); + DataSorter sorter = new AllFeatureDataSorter(); + for (int index = 0; index < userSize; index++) { + models[index] = sorter.sort(models[index]); + } + + DataInstance instance; + + int[] positiveKeys = new int[dimensionSizes.length], negativeKeys = new int[dimensionSizes.length]; + + ComputationGraphConfiguration configuration = getComputationGraphConfiguration(dimensionSizes); + + graph = new ComputationGraph(configuration); + graph.init(); + + for (int iterationStep = 1; iterationStep <= epocheSize; iterationStep++) { + totalError = 0F; + + // TODO 应该调整为配置项. 
+ int batchSize = 2000; + inputData = new INDArray[dimensionSizes.length + 1]; + inputData[dimensionSizes.length] = Nd4j.zeros(batchSize, dimensionSizes.length); + for (int index = 0; index < dimensionSizes.length; index++) { + inputData[index] = inputData[dimensionSizes.length].getColumn(index); + } + INDArray labelData = Nd4j.zeros(batchSize, 1); + + for (int batchIndex = 0; batchIndex < batchSize;) { + // 随机用户 + int userIndex = RandomUtility.randomInteger(userSize); + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0 || userVector.getElementSize() == itemSize) { + continue; + } + + DataModule module = models[userIndex]; + instance = module.getInstance(0); + // 获取正样本 + int positivePosition = RandomUtility.randomInteger(module.getSize()); + instance.setCursor(positivePosition); + for (int index = 0; index < positiveKeys.length; index++) { + positiveKeys[index] = instance.getQualityFeature(index); + } + + // 获取负样本 + int negativeItemIndex = RandomUtility.randomInteger(itemSize - userVector.getElementSize()); + for (int position = 0, size = userVector.getElementSize(); position < size; position++) { + if (negativeItemIndex >= userVector.getIndex(position)) { + negativeItemIndex++; + continue; + } + break; + } + // TODO 注意,此处为了故意制造负面特征. + int negativePosition = RandomUtility.randomInteger(module.getSize()); + instance.setCursor(negativePosition); + for (int index = 0; index < negativeKeys.length; index++) { + negativeKeys[index] = instance.getQualityFeature(index); + } + negativeKeys[itemDimension] = negativeItemIndex; + + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + // inputData[dimension].putScalar(batchIndex, 0, + // positiveKeys[dimension]); + inputData[dimensionSizes.length].putScalar(batchIndex, dimension, positiveKeys[dimension]); + } + labelData.put(batchIndex, 0, 1); + batchIndex++; + + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + // inputData[dimension].putScalar(batchIndex, 0, + // negativeKeys[dimension]); + inputData[dimensionSizes.length].putScalar(batchIndex, dimension, negativeKeys[dimension]); + } + labelData.put(batchIndex, 0, 0); + batchIndex++; + } + graph.setInputs(inputData); + graph.setLabels(labelData); + for (int iterationIndex = 0; iterationIndex < 100; iterationIndex++) { + graph.fit(); + } + + INDArray[] data = new INDArray[inputData.length]; + for (int index = 0; index < data.length; index++) { + data[index] = inputData[index].get(NDArrayIndex.interval(0, 10)); + } + System.out.println(graph.outputSingle(data)); + totalError = (float) graph.score(); + if (isConverged(iterationStep) && isConverged) { + break; + } + currentError = totalError; + } + + inputData[dimensionSizes.length] = Nd4j.zeros(userSize, dimensionSizes.length); + for (int index = 0; index < dimensionSizes.length; index++) { + inputData[index] = inputData[dimensionSizes.length].getColumn(index); + } + + for (int userIndex = 0; userIndex < userSize; userIndex++) { + DataModule model = models[userIndex]; + if (model.getSize() > 0) { + instance = model.getInstance(model.getSize() - 1); + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + if (dimension != itemDimension) { + int feature = instance.getQualityFeature(dimension); + // inputData[dimension].putScalar(userIndex, 0, + // keys[dimension]); + inputData[dimensionSizes.length].putScalar(userIndex, dimension, feature); + inputData[dimension].putScalar(userIndex, 0, feature); + } + } + } else { + 
inputData[dimensionSizes.length].putScalar(userIndex, userDimension, userIndex); + inputData[userDimension].putScalar(userIndex, 0, userIndex); + } + } + + outputData = Nd4j.zeros(userSize, itemSize); + + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + inputData[itemDimension].assign(itemIndex); + outputData.putColumn(itemIndex, graph.outputSingle(inputData)); + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = outputData.getFloat(userIndex, itemIndex); + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMOutputConfiguration.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMOutputConfiguration.java new file mode 100644 index 0000000..6f501c0 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMOutputConfiguration.java @@ -0,0 +1,72 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Collection; +import java.util.Map; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; + +/** + * + * DeepFM输出配置 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
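+ * <p>
+ * Usage sketch, taken from how {@code DeepFMModel} wires this output layer into its
+ * computation graph:
+ * <pre>
+ * graphBuilder.addLayer("DeepOutput",
+ *         new DeepFMOutputConfiguration.Builder(LossFunctions.LossFunction.XENT)
+ *                 .activation(Activation.SIGMOID).nIn(2).nOut(1).build(),
+ *         "FMOutput", "NetOutput");
+ * </pre>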
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMOutputConfiguration extends BaseOutputLayer { + + DeepFMOutputConfiguration() { + } + + protected DeepFMOutputConfiguration(Builder builder) { + super(builder); + } + + @Override + public ParamInitializer initializer() { + return DefaultParamInitializer.getInstance(); + } + + @Override + public Layer instantiate(NeuralNetConfiguration configuration, Collection monitors, int layerIndex, INDArray parameters, boolean initialize) { + DeepFMOutputLayer layer = new DeepFMOutputLayer(configuration); + layer.setListeners(monitors); + layer.setIndex(layerIndex); + layer.setParamsViewArray(parameters); + Map table = initializer().init(configuration, parameters, initialize); + layer.setParamTable(table); + layer.setConf(configuration); + return layer; + } + + public static class Builder extends BaseOutputLayer.Builder { + + public Builder(LossFunction lossFunction) { + super.lossFunction(lossFunction); + } + + public Builder(ILossFunction lossFunction) { + this.lossFn = lossFunction; + } + + @Override + public DeepFMOutputConfiguration build() { + return new DeepFMOutputConfiguration(this); + } + + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMOutputLayer.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMOutputLayer.java new file mode 100644 index 0000000..6597ada --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMOutputLayer.java @@ -0,0 +1,80 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.BaseOutputLayer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.primitives.Pair; + +/** + * + * DeepFM输出层 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMOutputLayer extends BaseOutputLayer { + + public DeepFMOutputLayer(NeuralNetConfiguration configuration) { + super(configuration); + } + + private Pair getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) { + ILossFunction lossFunction = layerConf().getLossFn(); + INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM); + // INDArray delta = lossFunction.computeGradient(labels2d, preOut, + // layerConf().getActivationFunction(), maskArray); + INDArray delta = lossFunction.computeGradient(labels2d, preOut, layerConf().getActivationFn(), maskArray); + + Gradient gradient = new DefaultGradient(); + + INDArray weightGradView = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); + Nd4j.gemm(input, delta, weightGradView, true, false, 1.0, 0.0); // Equivalent to: weightGradView.assign(input.transpose().mmul(delta)); + gradient.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGradView); + + if (hasBias()) { + INDArray biasGradView = gradientViews.get(DefaultParamInitializer.BIAS_KEY); + delta.sum(biasGradView, 0); // biasGradView is initialized/zeroed first in sum op + gradient.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGradView); + } + + delta = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta); + return new Pair<>(gradient, delta); + } + + @Override + public Pair backpropGradient(INDArray previous, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(true); + Pair pair = getGradientsAndDelta(preOutput2d(true, workspaceMgr), workspaceMgr); // Returns Gradient and delta^(this), not Gradient and epsilon^(this-1) + INDArray delta = pair.getSecond(); + + INDArray w = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); + INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[] { w.size(0), delta.size(0) }, 'f'); + epsilonNext = w.mmuli(delta.transpose(), epsilonNext).transpose(); + + // Normally we would clear weightNoiseParams here - but we want to reuse them + // for forward + backward + score + // So this is instead done in MultiLayerNetwork/CompGraph backprop methods + + epsilonNext = backpropDropOutIfPresent(epsilonNext); + return new Pair<>(pair.getFirst(), epsilonNext); + } + + @Override + protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) { + return labels; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMParameter.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMParameter.java new file mode 100644 index 0000000..b78b2c4 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMParameter.java @@ -0,0 +1,87 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.distribution.Distributions; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.rng.distribution.Distribution; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +/** + * + * DeepFM参数 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
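+ * <p>
+ * Layout of the flattened parameter view, as assumed by {@code numParams} and
+ * {@code init} below ({@code numberOfFeatures} is the sum of all dimension sizes):
+ * <pre>
+ * // [0, numberOfFeatures*nOut)                             -> "W" weights [numberOfFeatures, nOut]
+ * // [numberOfFeatures*nOut, numberOfFeatures*nOut + nOut)  -> "b" bias [1, nOut]
+ * </pre>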
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +class DeepFMParameter extends DefaultParamInitializer { + + private int numberOfFeatures; + + public int[] dimensionSizes; + + public DeepFMParameter(int... dimensionSizes) { + this.dimensionSizes = dimensionSizes; + this.numberOfFeatures = 0; + for (int dimensionSize : dimensionSizes) { + numberOfFeatures += dimensionSize; + } + } + + @Override + public long numParams(NeuralNetConfiguration configuration) { + FeedForwardLayer layerConfiguration = (FeedForwardLayer) configuration.getLayer(); + return numberOfFeatures * layerConfiguration.getNOut() + layerConfiguration.getNOut(); + } + + protected INDArray createWeightMatrix(NeuralNetConfiguration configuration, INDArray view, boolean initialize) { + FeedForwardLayer layerConfiguration = (FeedForwardLayer) configuration.getLayer(); + if (initialize) { + Distribution distribution = Distributions.createDistribution(layerConfiguration.getDist()); + return super.createWeightMatrix(numberOfFeatures, layerConfiguration.getNOut(), layerConfiguration.getWeightInit(), distribution, view, true); + } else { + return super.createWeightMatrix(numberOfFeatures, layerConfiguration.getNOut(), null, null, view, false); + } + } + + @Override + public Map init(NeuralNetConfiguration configuration, INDArray view, boolean initialize) { + Map parameters = Collections.synchronizedMap(new LinkedHashMap()); + FeedForwardLayer layerConfiguration = (FeedForwardLayer) configuration.getLayer(); + long numberOfOut = layerConfiguration.getNOut(); + long numberOfWeights = numberOfFeatures * numberOfOut; + INDArray weight = view.get(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.interval(0, numberOfWeights) }); + INDArray bias = view.get(NDArrayIndex.point(0), NDArrayIndex.interval(numberOfWeights, numberOfWeights + numberOfOut)); + + parameters.put(WEIGHT_KEY, this.createWeightMatrix(configuration, weight, initialize)); + parameters.put(BIAS_KEY, createBias(configuration, bias, initialize)); + configuration.addVariable(WEIGHT_KEY); + configuration.addVariable(BIAS_KEY); + return parameters; + } + + @Override + public Map getGradientsFromFlattened(NeuralNetConfiguration configuration, INDArray view) { + Map gradients = new LinkedHashMap<>(); + FeedForwardLayer layerConfiguration = (FeedForwardLayer) configuration.getLayer(); + long numberOfOut = layerConfiguration.getNOut(); + long numberOfWeights = numberOfFeatures * numberOfOut; + INDArray weight = view.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, numberOfWeights)).reshape('f', numberOfWeights, numberOfOut); + INDArray bias = view.get(NDArrayIndex.point(0), NDArrayIndex.interval(numberOfWeights, numberOfWeights + numberOfOut)); + gradients.put(WEIGHT_KEY, weight); + gradients.put(BIAS_KEY, bias); + return gradients; + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMProductConfiguration.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMProductConfiguration.java new file mode 100644 index 0000000..81014fa --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMProductConfiguration.java @@ -0,0 +1,74 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import org.deeplearning4j.nn.conf.graph.GraphVertex; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.memory.MemoryReport; +import 
org.deeplearning4j.nn.graph.ComputationGraph;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ *
+ * DeepFM Product configuration
+ *
+ * <pre>
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
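+ * <p>
+ * Usage sketch, taken from how {@code DeepFMModel} adds one product vertex per pair
+ * of embedding vertices:
+ * <pre>
+ * graphBuilder.addVertex("FMProduct0:1", new DeepFMProductConfiguration(), "Embed0", "Embed1");
+ * </pre>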
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMProductConfiguration extends GraphVertex { + + public DeepFMProductConfiguration() { + } + + @Override + public DeepFMProductConfiguration clone() { + return new DeepFMProductConfiguration(); + } + + @Override + public boolean equals(Object other) { + return other instanceof DeepFMProductConfiguration; + } + + @Override + public int hashCode() { + return DeepFMProductConfiguration.class.hashCode(); + } + + @Override + public long numParams(boolean backprop) { + return 0; + } + + @Override + public int minVertexInputs() { + return 2; + } + + @Override + public int maxVertexInputs() { + return 2; + } + + @Override + public DeepFMProductVertex instantiate(ComputationGraph graph, String name, int vertexIndex, INDArray paramsView, boolean initializeParams) { + return new DeepFMProductVertex(graph, name, vertexIndex); + } + + @Override + public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { + return InputType.feedForward(1); + } + + @Override + public MemoryReport getMemoryReport(InputType... inputTypes) { + // No working memory in addition to output activations + return new LayerMemoryReport.Builder(null, DeepFMProductConfiguration.class, inputTypes[0], inputTypes[0]).standardMemory(0, 0).workingMemory(0, 0, 0, 0).cacheMemory(0, 0).build(); + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMProductVertex.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMProductVertex.java new file mode 100644 index 0000000..28727b2 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMProductVertex.java @@ -0,0 +1,141 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; +import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.Or; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +/** + * + * DeepFM Product节点 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
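+ * <p>
+ * Computation sketch (informal notation mirroring {@code doForward} and
+ * {@code doBackward} below); the backward pass is the usual product rule:
+ * <pre>
+ * // forward:  output[i] = dot(left.row(i), right.row(i))   -> shape [batchSize, 1]
+ * // backward: epsilons[0].row(i) = right.row(i) * epsilon[i]
+ * //           epsilons[1].row(i) = left.row(i) * epsilon[i]
+ * </pre>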
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMProductVertex extends BaseGraphVertex { + + public DeepFMProductVertex(ComputationGraph graph, String name, int vertexIndex) { + this(graph, name, vertexIndex, null, null); + } + + public DeepFMProductVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices) { + super(graph, name, vertexIndex, inputVertices, outputVertices); + } + + @Override + public boolean hasLayer() { + return false; + } + + @Override + public boolean isOutputVertex() { + return false; + } + + @Override + public Layer getLayer() { + return null; + } + + @Override + public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { + if (!canDoForward()) { + throw new IllegalStateException("Cannot do forward pass: inputs not set"); + } + // inputs[index] => {batchSize, numberOfEmbeds} + INDArray left = inputs[0]; + INDArray right = inputs[1]; + long size = inputs[0].shape()[0]; + INDArray value = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, size); + // 求两个行向量的点积 + for (int index = 0; index < size; index++) { + INDArray product = left.getRow(index).mmul(right.getRow(index).transpose()); + value.put(index, product); + } + // outputs[index] => {batchSize, 1} + return Shape.newShapeNoCopy(value, new long[] { value.length(), 1L }, value.ordering() == 'f'); + } + + @Override + public Pair doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { + if (!canDoBackward()) { + throw new IllegalStateException("Cannot do backward pass: errors not set"); + } + + // epsilons[index] => {batchSize, numberOfEmbeds} + INDArray[] epsilons = new INDArray[inputs.length]; + epsilons[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, inputs[0]); + epsilons[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, inputs[1]); + // epsilon => {batchSize, 1} + // inputs[index] => {batchSize, numberOfEmbeds} + // TODO 如何通过inputs[index]与epsilon求导epsilons[index] + INDArray left = inputs[0]; + INDArray right = inputs[1]; + for (int index = 0; index < epsilon.rows(); index++) { + epsilons[0].putRow(index, right.getRow(index).transpose().mmul(epsilon.getRow(index)).transpose()); + epsilons[1].putRow(index, left.getRow(index).transpose().mmul(epsilon.getRow(index)).transpose()); + } + return new Pair<>(null, epsilons); + } + + @Override + public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) { + if (backpropGradientsViewArray != null) + throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here"); + } + + @Override + public Pair feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) { + if (maskArrays == null) { + return new Pair<>(null, currentMaskState); + } + + // Most common case: all or none. + // If there's only *some* mask arrays: assume the others (missing) are + // equivalent to all 1s + // And for handling multiple masks: best strategy seems to be an OR + // operation + // i.e., output is 1 if any of the input are 1s + // Which means: if any masks are missing, output null (equivalent to no + // mask, or all steps present) + // Otherwise do an element-wise OR operation + + for (INDArray mask : maskArrays) { + if (mask == null) { + return new Pair<>(null, currentMaskState); + } + } + + // At this point: all present. 
Do OR operation + if (maskArrays.length == 1) { + return new Pair<>(maskArrays[0], currentMaskState); + } else { + INDArray mask = maskArrays[0].dup(maskArrays[0].ordering()); + for (int index = 1; index < maskArrays.length; index++) { + Nd4j.getExecutioner().exec(new Or(maskArrays[index], mask, mask)); + } + return new Pair<>(mask, currentMaskState); + } + } + + @Override + public String toString() { + return "DeepFMProductVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\")"; + } + +} \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMSumConfiguration.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMSumConfiguration.java new file mode 100644 index 0000000..eaaa05d --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMSumConfiguration.java @@ -0,0 +1,74 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import org.deeplearning4j.nn.conf.graph.GraphVertex; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.memory.MemoryReport; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * + * DeepFM Sum配置 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
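+ * <p>
+ * Usage sketch, taken from how {@code DeepFMModel} sums the FM terms (shown here with
+ * only one product vertex for brevity):
+ * <pre>
+ * graphBuilder.addVertex("FMOutput", new DeepFMSumConfiguration(), "FMProduct0:1", "FMPlus");
+ * </pre>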
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMSumConfiguration extends GraphVertex { + + public DeepFMSumConfiguration() { + } + + @Override + public DeepFMSumConfiguration clone() { + return new DeepFMSumConfiguration(); + } + + @Override + public boolean equals(Object other) { + return other instanceof DeepFMSumConfiguration; + } + + @Override + public int hashCode() { + return DeepFMSumConfiguration.class.hashCode(); + } + + @Override + public long numParams(boolean backprop) { + return 0; + } + + @Override + public int minVertexInputs() { + return 2; + } + + @Override + public int maxVertexInputs() { + return Integer.MAX_VALUE; + } + + @Override + public DeepFMSumVertex instantiate(ComputationGraph graph, String name, int vertexIndex, INDArray paramsView, boolean initializeParams) { + return new DeepFMSumVertex(graph, name, vertexIndex); + } + + @Override + public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { + return InputType.feedForward(1); + } + + @Override + public MemoryReport getMemoryReport(InputType... inputTypes) { + // No working memory in addition to output activations + return new LayerMemoryReport.Builder(null, DeepFMSumConfiguration.class, inputTypes[0], inputTypes[0]).standardMemory(0, 0).workingMemory(0, 0, 0, 0).cacheMemory(0, 0).build(); + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMSumVertex.java b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMSumVertex.java new file mode 100644 index 0000000..18da1bb --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMSumVertex.java @@ -0,0 +1,133 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; +import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.Or; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +/** + * + * DeepFM Sum节点 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
+ * 
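+ * <p>
+ * Computation sketch (informal notation mirroring {@code doForward} and {@code doBackward}
+ * below). The implemented backward pass distributes epsilon[i] across the summed elements
+ * in proportion to their share of the row total; the TODO in {@code doBackward} marks this
+ * derivation as open, since the plain derivative of a sum would copy epsilon to every element:
+ * <pre>
+ * // forward:  output[i] = sum over inputs k of sum(inputs[k].row(i))   -> [batchSize, 1]
+ * // backward: epsilons[k][i][j] = inputs[k][i][j] * epsilon[i] / output[i]
+ * </pre>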
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class DeepFMSumVertex extends BaseGraphVertex { + + public DeepFMSumVertex(ComputationGraph graph, String name, int vertexIndex) { + this(graph, name, vertexIndex, null, null); + } + + public DeepFMSumVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices) { + super(graph, name, vertexIndex, inputVertices, outputVertices); + } + + @Override + public boolean hasLayer() { + return false; + } + + @Override + public boolean isOutputVertex() { + return false; + } + + @Override + public Layer getLayer() { + return null; + } + + @Override + public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { + if (!canDoForward()) { + throw new IllegalStateException("Cannot do forward pass: inputs not set"); + } + + // inputs[index] => {batchSize, numberOfEmbeds} + INDArray output = inputs[0].sum(1); + // 求N个行向量的总和 + for (int index = 1; index < inputs.length; index++) { + output.addi(inputs[index].sum(1)); + } + // output => {batchSize, 1} + return workspaceMgr.dup(ArrayType.ACTIVATIONS, output); + } + + @Override + public Pair doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { + if (!canDoBackward()) { + throw new IllegalStateException("Cannot do backward pass: errors not set"); + } + // epsilons[index] => {batchSize, numberOfEmbeds} + INDArray[] epsilons = new INDArray[inputs.length]; + // epsilon => {batchSize, 1} + // inputs[index] => {batchSize, numberOfEmbeds} + // TODO 如何通过inputs[index]与epsilon求导epsilons[index] + INDArray output = doForward(true, workspaceMgr); + for (int index = 0; index < inputs.length; index++) { + epsilons[index] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, inputs[index]); + epsilons[index].muliColumnVector(epsilon).diviColumnVector(output); + } + return new Pair<>(null, epsilons); + } + + @Override + public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) { + if (backpropGradientsViewArray != null) + throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here"); + } + + @Override + public Pair feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) { + if (maskArrays == null) { + return new Pair<>(null, currentMaskState); + } + + // Most common case: all or none. + // If there's only *some* mask arrays: assume the others (missing) are + // equivalent to all 1s + // And for handling multiple masks: best strategy seems to be an OR + // operation + // i.e., output is 1 if any of the input are 1s + // Which means: if any masks are missing, output null (equivalent to no + // mask, or all steps present) + // Otherwise do an element-wise OR operation + + for (INDArray mask : maskArrays) { + if (mask == null) { + return new Pair<>(null, currentMaskState); + } + } + + // At this point: all present. 
Do OR operation + if (maskArrays.length == 1) { + return new Pair<>(maskArrays[0], currentMaskState); + } else { + INDArray mask = maskArrays[0].dup(maskArrays[0].ordering()); + for (int index = 1; index < maskArrays.length; index++) { + Nd4j.getExecutioner().exec(new Or(maskArrays[index], mask, mask)); + } + return new Pair<>(mask, currentMaskState); + } + } + + @Override + public String toString() { + return "DeepFMActivationVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\")"; + } + +} \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecLearner.java b/src/main/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecLearner.java new file mode 100644 index 0000000..3ed84c9 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecLearner.java @@ -0,0 +1,93 @@ +package com.jstarcraft.rns.model.dl4j.rating; + +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.linalg.primitives.Pair; + +/** + * + * AutoRec学习器 + * + *
+ * AutoRec: Autoencoders Meet Collaborative Filtering
+ * Based on the implementation by the LibRec team
+ * 
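+ * <p>
+ * Loss sketch (informal notation mirroring {@code computeScore} and {@code computeGradient}
+ * below; {@code maskData} zeroes the unobserved entries so only observed scores contribute):
+ * <pre>
+ * // score    = sum( maskData * (labels - activation(preOutput))^2 )
+ * // gradient = backprop( -2 * (labels - output) ) * maskData
+ * </pre>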
+ * + * @author Birdy + * + */ +//TODO 存档,以后需要基于DL4J重构. +@Deprecated +public class AutoRecLearner implements ILossFunction { + + private INDArray maskData; + + public AutoRecLearner(INDArray maskData) { + this.maskData = maskData; + } + + private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray scoreArr; + INDArray output = activationFn.getActivation(preOutput.dup(), true); + INDArray yMinusyHat = Transforms.abs(labels.sub(output)); + scoreArr = yMinusyHat.mul(yMinusyHat); + scoreArr = scoreArr.mul(maskData); + + if (mask != null) { + scoreArr.muliColumnVector(mask); + } + return scoreArr; + } + + @Override + public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask); + double score = scoreArr.sumNumber().doubleValue(); + + if (average) { + score /= scoreArr.size(0); + } + + return score; + } + + @Override + public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask); + return scoreArr.sum(1); + } + + @Override + public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + INDArray output = activationFn.getActivation(preOutput.dup(), true); + INDArray yMinusyHat = labels.sub(output); + INDArray dldyhat = yMinusyHat.mul(-2); + + INDArray gradients = activationFn.backprop(preOutput.dup(), dldyhat).getFirst(); + gradients = gradients.mul(maskData); + // multiply with masks, always + if (mask != null) { + gradients.muliColumnVector(mask); + } + + return gradients; + } + + @Override + public Pair computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average), computeGradient(labels, preOutput, activationFn, mask)); + } + + @Override + public String toString() { + return super.toString() + "AutoRecLossFunction"; + } + + @Override + public String name() { + // TODO Auto-generated method stub + return toString(); + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecModel.java b/src/main/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecModel.java new file mode 100644 index 0000000..c9b17ea --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecModel.java @@ -0,0 +1,77 @@ +package com.jstarcraft.rns.model.dl4j.rating; + +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Nesterovs; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.NeuralNetworkModel; + +/** + * + * AutoRec学习器 + * + *
+ * AutoRec: Autoencoders Meet Collaborative Filtering
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +//TODO Archived; to be refactored on top of DL4J later. +@Deprecated +public class AutoRecModel extends NeuralNetworkModel { + + /** + * the data structure that indicates which element in the user-item is non-zero + */ + private INDArray maskData; + + @Override + protected int getInputDimension() { + return userSize; + } + + @Override + protected MultiLayerConfiguration getNetworkConfiguration() { + NeuralNetConfiguration.ListBuilder factory = new NeuralNetConfiguration.Builder().seed(6).updater(new Nesterovs(learnRatio, momentum)).weightInit(WeightInit.XAVIER_UNIFORM).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(weightRegularization).list(); + factory.layer(0, new DenseLayer.Builder().nIn(inputDimension).nOut(hiddenDimension).activation(Activation.fromString(hiddenActivation)).build()); + factory.layer(1, new OutputLayer.Builder(new AutoRecLearner(maskData)).nIn(hiddenDimension).nOut(inputDimension).activation(Activation.fromString(outputActivation)).build()); + MultiLayerConfiguration configuration = factory.pretrain(false).backprop(true).build(); + return configuration; + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + // transform the sparse matrix to INDArray + int[] matrixShape = new int[] { itemSize, userSize }; + inputData = Nd4j.zeros(matrixShape); + maskData = Nd4j.zeros(matrixShape); + for (MatrixScalar term : scoreMatrix) { + if (term.getValue() > 0D) { + inputData.putScalar(term.getColumn(), term.getRow(), term.getValue()); + maskData.putScalar(term.getColumn(), term.getRow(), 1D); + } + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(outputData.getFloat(itemIndex, userIndex)); + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/ensemble/bandit/bandit.txt b/src/main/java/com/jstarcraft/rns/model/ensemble/bandit/bandit.txt new file mode 100644 index 0000000..adfef29 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/ensemble/bandit/bandit.txt @@ -0,0 +1,10 @@ +Bandit algorithms and recommender systems: +https://blog.csdn.net/heyc861221/article/details/80129310 + +https://github.com/wealthfront/thompson-sampling + +https://github.com/danisola/bandit + +https://github.com/johnmyleswhite/BanditsBook + +https://github.com/huazhengwang/BanditLib \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/ensemble/ensemble.txt b/src/main/java/com/jstarcraft/rns/model/ensemble/ensemble.txt new file mode 100644 index 0000000..17f9f2e --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/ensemble/ensemble.txt @@ -0,0 +1,153 @@ +A survey of ensemble learning: +http://rec-sys.net/plugin.php?id=freeaddon_pdf_preview:pdf&pid=697&aid=101&md5hash=2591efd25fffc39ab95630f1d5bc536b +Ensemble learning divides into heterogeneous and homogeneous ensemble learning. +Heterogeneous ensemble learning further divides into stacked generalization (Stack Generalization) and meta learning (Meta Learning). +Stacked generalization (Stack Generalization) is simply Stacking. +Meta learning mainly includes the arbiter approach and the combiner approach. +Bagging and Boosting belong to the combiner approach. + +When using ensemble algorithms, a basic rule of thumb applies (see the sketch below): +If the models are homogeneous and most of them underfit (high bias, low variance), choose boosting. +If the models are heterogeneous and most of them underfit (high bias, low variance), choose stacking. +If the models are homogeneous and most of them overfit (low bias, high variance), choose bagging. +If the models are heterogeneous and most of them are accurate (neither high bias nor high variance), choose stacking. 
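+
+As a toy illustration (not taken from any of the articles below), the rule of thumb above could be sketched in Java; the boolean flags describing the base models are hypothetical:
+
+// Toy sketch: encode the selection rule above. "homogeneous" says whether the
+// base models are of the same kind; "overfitting" describes their dominant error.
+static String chooseEnsembleStrategy(boolean homogeneous, boolean overfitting) {
+    if (!homogeneous) {
+        return "stacking"; // both heterogeneous rules above choose stacking
+    }
+    return overfitting ? "bagging" : "boosting";
+}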
+ +Notes: Ensemble Learning with tree models, Bagging and Boosting, and model fusion: +http://blog.csdn.net/sinat_26917383/article/details/54667077 + +Machine learning --> ensemble learning --> Bagging, Boosting, Stacking: +http://blog.csdn.net/Mr_tyting/article/details/72957853 + +Ensemble learning - bagging/boosting/stacking: +https://www.zhihu.com/question/29036379 + +[Translation] The three pillars of ensemble learning - bagging, boosting, stacking: +https://zhuanlan.zhihu.com/p/36161812 + +A summary of hybrid recommendation techniques: +https://my.oschina.net/liangtee/blog/119106 + +Hybrid recommendation techniques explained: +https://arthur503.github.io/blog/2013/09/30/mixed-recommend-systems.html + +An analysis of the Meizu recommendation platform architecture: +http://blog.csdn.net/tech_meizu/article/details/70207570 + +An introduction to the hybrid techniques used in recommender systems: +https://www.52ml.net/318.html + +In real industrial deployments, a popular way to solve such common problems is a three-stage hybrid system: +namely the Online-Nearline-Offline Recommendation mechanism. +The Online system faces users directly and is a high-performance, high-availability recommendation service; a cache is usually designed here to handle the repeated computation of hot queries. +On a cache miss, the Online recommender runs a simple and reliable algorithm to generate results in time. +Behind the Online system sits the Nearline system, deployed on the server side; on one hand it receives requests from the Online system and recomputes some cached Online results with more sophisticated algorithms before refreshing the cache. +On the other hand, Nearline bridges the Online and Offline systems: Offline results usually come from mining long-term, massive user behavior logs, which consumes large resources over long cycles, but the quality of Offline results is usually the highest; these results are delivered online through the Nearline system. + +Hybrid recommender systems are another research hotspot: multiple recommendation techniques are mixed so that they compensate for each other's weaknesses and yield better recommendations. +The most common combination blends collaborative filtering with other techniques to overcome the cold start problem. + + Integral strategies: + +(1) Feature Combination +Features from different recommendation data sources are combined and consumed by another recommendation technique. +Typically, collaborative filtering information is added as extra feature vectors, and a content-based technique is applied on the enlarged data set. +The Feature Combination style means the system no longer considers only the collaborative data source, so it is less sensitive to the number of ratings per user; conversely, it gives the system internal item similarity information that is opaque to a purely collaborative system. + +Understanding: +Feature combination simply preprocesses data from different sources into a richer feature set and feeds it into the recommendation algorithm as part of the input. + +Questions: +1. Under this reading, the core algorithm of a Feature Combination strategy must be one that can handle features, such as content-based methods or factorization machines; in LibRec this corresponds to TensorRecommender. +2. Today's popular tag-based recommendation is a generalization of the Feature Combination idea. + +(2) Feature Augmentation +The output of one recommendation method serves as the input of the next. +For example, clustering can act as preprocessing for association rules: first cluster the session files, then mine association rules per cluster. When a new session arrives, compute its match against each cluster to decide which cluster it belongs to, then recommend with that cluster's association rules. +How does this type differ from the cascade type? +In Feature Augmentation, the features used by the second method include the output of the first. +In a cascade, the second method does not use any ranked output produced by the first; the results of the two methods are blended in an optimized way. + +Understanding: +Feature Augmentation uses other machine learning methods in the preprocessing stage to complete missing data or mine new features, and feeds them into the recommendation algorithm as part of the input. + +Question: +1. Under this reading, the core algorithm of a Feature Augmentation strategy can be itemknn or userknn, which only accept a score matrix, or a content-based method or factorization machine, which accept a feature matrix. + +Summary: +Both integral strategies seem to emphasize data preprocessing (including feature engineering); depending on the scenario, an existing algorithm is used directly or a new one is designed (usually a small modification of an old method). +Integral preprocessing is closely related to modern feature engineering. +The integral mindset emphasizes improving old algorithms or designing new ones (for example, LambdaFM borrows the ideas of BPR and FM, and SBPR borrows BPR plus social features). + +Note: +The difference between Feature Augmentation and Cascade: +1. Feature Augmentation emphasizes completing data or mining features in the preprocessing stage; the number of core algorithms is always 1. +2. Cascade emphasizes feeding the previous algorithm's output directly into the next; the number of core algorithms is >= 2. + + Parallel strategies: + +(3) Mix +Multiple recommendation techniques are applied simultaneously, producing several recommendation results for the user's reference. +For example, one can build a personalized recommender based on mining logs and cache data: the system first builds multi-faceted user interest patterns from the logs and cache, then matches the target user's short-term visit history against the long-term interest patterns and applies content-based filtering to recommend similar pages. +At the same time, collaborative filtering across users predicts the pages the target user is most likely to visit next, ranks them by score, and appends them after the page the user currently requests, i.e. "pages you may also be interested in". + +Understanding: +Mix is a rather simple strategy: the results of different recommendation algorithms/systems are aggregated together; the algorithms and systems are neither related nor affect each other. + +Doubt: +This strategy sometimes seems to need an extra filtering mechanism to ensure the results of the individual algorithms are relevant? + +(4) Weight +The results of multiple recommendation techniques are blended with weights to produce recommendations (see the sketch after this section). +The simplest form is a linear blend: first give the collaborative filtering results and the content-based results equal weights, then compare the users' actual evaluations of items with the system's predictions and adjust the weights. +The characteristic of weighted hybrids is that the performance of the whole system is directly tied to the recommendation process, which makes it easy to assign credit afterwards and tune the blend model accordingly; however, this technique assumes that, for all possible items in the space, the relevant parameter values of the different techniques are roughly comparable. + +Understanding: +Weight is a refinement of Mix: the results of different recommendation algorithms/systems are adjusted by different weights (static or dynamic) to obtain the final score or ranking. + +Doubt: +This strategy sometimes seems to need an extra weighting mechanism that adjusts each algorithm's/system's weight according to recommendation performance? + +(5) Switch +Different recommendation techniques are adopted according to the problem context and the actual situation. +For example, in a content-based plus collaborative hybrid, the system first uses the content-based technique, and if that cannot produce high-confidence recommendations, it then tries collaborative filtering. +Because switching criteria must be compared across situations, this approach increases algorithmic complexity and parameterization; the benefit is that it is sensitive to the strengths and weaknesses of each technique. + +Understanding: +Switch is a refinement of Weight: depending on the system state (cold start, etc.) or the user's specific context, a different recommendation algorithm/system is chosen to serve them. + +Doubt: +This strategy seems to need an extra arbitration mechanism or knowledge system to decide which algorithm/system to use in which situation? + +Summary: +All three parallel strategies seem to require an extra coordination mechanism to work effectively. 
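+
+A minimal Java sketch of the Weight strategy above; the Recommender interface and the static weight array are assumptions for illustration, not part of any library mentioned here:
+
+// Blend the scores of several recommenders with static weights.
+interface Recommender {
+    float score(int userIndex, int itemIndex);
+}
+
+class WeightedHybrid implements Recommender {
+    private final Recommender[] recommenders;
+    private final float[] weights; // could also be adjusted dynamically from feedback
+
+    WeightedHybrid(Recommender[] recommenders, float[] weights) {
+        this.recommenders = recommenders;
+        this.weights = weights;
+    }
+
+    @Override
+    public float score(int userIndex, int itemIndex) {
+        float value = 0F;
+        for (int index = 0; index < recommenders.length; index++) {
+            value += weights[index] * recommenders[index].score(userIndex, itemIndex);
+        }
+        return value;
+    }
+}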
+ + Serial strategies: + +(6) Cascade +The later recommendation method refines the earlier one: it is a staged process in which one technique first produces a rather rough candidate result, and a second technique then makes further, more precise recommendations on that basis. +Cascading lets the system avoid applying the lower-priority technique to certain items, either items already well discriminated by the first technique, or items rated so rarely that they would never be recommended. +The second step of a cascade focuses only on the items that need additional judgment. Moreover, a cascade is more tolerant of errors in the lower-priority technique, because the scores produced by the higher-priority technique only become more precise rather than being completely overwritten. + +Understanding: +Cascade feeds the previous algorithm's/system's recommendation list into the next one as input. (The score matrix and feature matrix can still accompany the recommendation list as inputs to the next algorithm.) +The previous algorithm's recommendation list bounds the recommendation scope of the next. +For example, knowledge-based recommendation lists are usually unordered or contain many items with identical scores; other algorithms can be used in concert to further refine the list. + +Note: +The next algorithm can only ever change the scores or the size of the previous algorithm's recommendation list. + +(7) Meta Level +The model produced by one recommendation method serves as the input of another. +The difference from Feature Augmentation is this: +Feature Augmentation uses a learned model to produce certain features as input for the second algorithm, whereas in Meta Level the whole model becomes the input. +For example, user-based and item-based collaborative filtering can be combined: first derive the set of items similar to the target item, then apply user-based collaborative filtering over that similar-item set. This neighbor-user collaborative method based on similar items handles personalized recommendation well when users have multiple interests, and performs even better when the content attributes of the candidate items differ widely. + +Understanding: +Meta Level takes the model values produced by the previous machine learning algorithm as the initial values of the next algorithm's model. + +Questions: +1. Under this reading, the models of the participating algorithms/systems must be at least somewhat similar; for example, every matrix factorization variant has user factors and item factors, which is what makes Meta Level feasible. +2. Under this reading, the algorithms Meta Level can combine are necessarily limited to machine learning. + +Summary: +Both serial strategies seem to emphasize combination at the algorithm level; as long as the conditions are met, algorithms can be freely added, removed, or reordered without affecting other systems. \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/exception/ModelException.java b/src/main/java/com/jstarcraft/rns/model/exception/ModelException.java new file mode 100644 index 0000000..fb61e6d --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/exception/ModelException.java @@ -0,0 +1,25 @@ +package com.jstarcraft.rns.model.exception; + +/** + * Recommendation exception + * + * @author Birdy + * + */ +public class ModelException extends RuntimeException { + + private static final long serialVersionUID = 4072415788185880975L; + + public ModelException(String message) { + super(message); + } + + public ModelException(Throwable exception) { + super(exception); + } + + public ModelException(String message, Throwable exception) { + super(message, exception); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/extend/ranking/AssociationRuleModel.java b/src/main/java/com/jstarcraft/rns/model/extend/ranking/AssociationRuleModel.java new file mode 100644 index 0000000..92e1aa2 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/extend/ranking/AssociationRuleModel.java @@ -0,0 +1,133 @@ +package com.jstarcraft.rns.model.extend.ranking; + +import java.util.Iterator; +import java.util.concurrent.Semaphore; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.environment.EnvironmentContext; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.AbstractModel; +import com.jstarcraft.rns.model.exception.ModelException; + +/** + * + * Association Rule recommender + * + *
+ * A Recommendation Algorithm Using Multi-Level Association Rules
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class AssociationRuleModel extends AbstractModel { + + /** + * confidence matrix of association rules + */ + private DenseMatrix associationMatrix; + + /** + * setup + * + * @throws ModelException if error occurs + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + associationMatrix = DenseMatrix.valueOf(itemSize, itemSize); + } + + @Override + protected void doPractice() { + EnvironmentContext context = EnvironmentContext.getContext(); + Semaphore semaphore = new Semaphore(0); + // simple rule: X => Y, given that each user vector is regarded as a + // transaction + for (int leftItemIndex = 0; leftItemIndex < itemSize; leftItemIndex++) { + // all transactions for item itemIdx + SparseVector leftVector = scoreMatrix.getColumnVector(leftItemIndex); + for (int rightItemIndex = leftItemIndex + 1; rightItemIndex < itemSize; rightItemIndex++) { + SparseVector rightVector = scoreMatrix.getColumnVector(rightItemIndex); + int leftIndex = leftItemIndex; + int rightIndex = rightItemIndex; + context.doAlgorithmByAny(leftItemIndex * rightItemIndex, () -> { + int leftCursor = 0, rightCursor = 0, leftSize = leftVector.getElementSize(), rightSize = rightVector.getElementSize(); + if (leftSize != 0 && rightSize != 0) { + // compute confidence where containing item assoItemIdx + // among + // userRatingsVector + int count = 0; + Iterator leftIterator = leftVector.iterator(); + Iterator rightIterator = rightVector.iterator(); + VectorScalar leftTerm = leftIterator.next(); + VectorScalar rightTerm = rightIterator.next(); + // 判断两个有序数组中是否存在相同的数字 + while (leftCursor < leftSize && rightCursor < rightSize) { + if (leftTerm.getIndex() == rightTerm.getIndex()) { + count++; + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + leftCursor++; + rightCursor++; + } else if (leftTerm.getIndex() > rightTerm.getIndex()) { + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + rightCursor++; + } else if (leftTerm.getIndex() < rightTerm.getIndex()) { + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + leftCursor++; + } + } + float leftValue = (count + 0F) / leftVector.getElementSize(); + float rightValue = (count + 0F) / rightVector.getElementSize(); + associationMatrix.setValue(leftIndex, rightIndex, leftValue); + associationMatrix.setValue(rightIndex, leftIndex, rightValue); + } + semaphore.release(); + }); + } + try { + semaphore.acquire(itemSize - leftItemIndex - 1); + } catch (Exception exception) { + throw new ModelException(exception); + } + } + } + + /** + * predict a specific rating for user userIdx on item itemIdx. 
+ * + * @param userIndex user index + * @param itemIndex item index + * @return predictive rating for user userIdx on item itemIdx + * @throws ModelException if error occurs + */ + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = 0F; + for (VectorScalar term : scoreMatrix.getRowVector(userIndex)) { + int associationIndex = term.getIndex(); + float association = associationMatrix.getValue(associationIndex, itemIndex); + double score = term.getValue(); + value += score * association; + } + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/extend/ranking/PRankDModel.java b/src/main/java/com/jstarcraft/rns/model/extend/ranking/PRankDModel.java new file mode 100644 index 0000000..480520e --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/extend/ranking/PRankDModel.java @@ -0,0 +1,140 @@ +package com.jstarcraft.rns.model.extend.ranking; + +import java.util.List; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.algorithm.correlation.MathCorrelation; +import com.jstarcraft.ai.math.structure.matrix.SymmetryMatrix; +import com.jstarcraft.ai.math.structure.vector.DenseVector; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.common.reflection.ReflectionUtility; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.collaborative.ranking.RankSGDModel; +import com.jstarcraft.rns.model.exception.ModelException; +import com.jstarcraft.rns.utility.SampleUtility; + +import it.unimi.dsi.fastutil.ints.IntSet; + +/** + * + * PRankD推荐器 + * + *
+ * Personalised ranking with diversity
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class PRankDModel extends RankSGDModel { + /** + * item importance + */ + private DenseVector itemWeights; + + /** + * item correlations + */ + private SymmetryMatrix itemCorrelations; + + /** + * similarity filter + */ + private float similarityFilter; + + /** + * initialization + * + * @throws ModelException if error occurs + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + similarityFilter = configuration.getFloat("recommender.sim.filter", 4F); + float denominator = 0F; + itemWeights = DenseVector.valueOf(itemSize); + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + float numerator = scoreMatrix.getColumnScope(itemIndex); + denominator = denominator < numerator ? numerator : denominator; + itemWeights.setValue(itemIndex, numerator); + } + // compute item relative importance + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + itemWeights.setValue(itemIndex, itemWeights.getValue(itemIndex) / denominator); + } + + // compute item correlations by cosine similarity + // TODO 修改为配置枚举 + try { + Class correlationClass = (Class) Class.forName(configuration.getString("recommender.correlation.class")); + MathCorrelation correlation = ReflectionUtility.getInstance(correlationClass); + itemCorrelations = new SymmetryMatrix(scoreMatrix.getColumnSize()); + correlation.calculateCoefficients(scoreMatrix, true, itemCorrelations::setValue); + } catch (Exception exception) { + throw new RuntimeException(exception); + } + } + + /** + * train model + * + * @throws ModelException if error occurs + */ + @Override + protected void doPractice() { + List userItemSet = getUserItemSet(scoreMatrix); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + // for each rated user-item (u,i) pair + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0) { + continue; + } + IntSet itemSet = userItemSet.get(userIndex); + for (VectorScalar term : userVector) { + // each rated item i + int positiveItemIndex = term.getIndex(); + float positiveScore = term.getValue(); + int negativeItemIndex = -1; + do { + // draw an item j with probability proportional to + // popularity + negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1))); + // ensure that it is unrated by user u + } while (itemSet.contains(negativeItemIndex)); + float negativeScore = 0F; + // compute predictions + float positivePredict = predict(userIndex, positiveItemIndex), negativePredict = predict(userIndex, negativeItemIndex); + float distance = (float) Math.sqrt(1 - Math.tanh(itemCorrelations.getValue(positiveItemIndex, negativeItemIndex) * similarityFilter)); + float itemWeight = itemWeights.getValue(negativeItemIndex); + float error = itemWeight * (positivePredict - negativePredict - distance * (positiveScore - negativeScore)); + totalError += error * error; + + // update vectors + float learnFactor = learnRatio * error; + for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) { + float userFactor = userFactors.getValue(userIndex, factorIndex); + float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex); + float negativeItemFactor = itemFactors.getValue(negativeItemIndex, 
factorIndex); + userFactors.shiftValue(userIndex, factorIndex, -learnFactor * (positiveItemFactor - negativeItemFactor)); + itemFactors.shiftValue(positiveItemIndex, factorIndex, -learnFactor * userFactor); + itemFactors.shiftValue(negativeItemIndex, factorIndex, learnFactor * userFactor); + } + } + } + + totalError *= 0.5F; + if (isConverged(epocheIndex) && isConverged) { + break; + } + isLearned(epocheIndex); + currentError = totalError; + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/extend/rating/PersonalityDiagnosisModel.java b/src/main/java/com/jstarcraft/rns/model/extend/rating/PersonalityDiagnosisModel.java new file mode 100644 index 0000000..d9239c0 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/extend/rating/PersonalityDiagnosisModel.java @@ -0,0 +1,153 @@ +package com.jstarcraft.rns.model.extend.rating; + +import java.util.Iterator; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.AbstractModel; +import com.jstarcraft.rns.model.exception.ModelException; + +import it.unimi.dsi.fastutil.floats.FloatArrayList; +import it.unimi.dsi.fastutil.floats.FloatList; +import it.unimi.dsi.fastutil.floats.FloatRBTreeSet; +import it.unimi.dsi.fastutil.floats.FloatSet; + +/** + * + * Personality Diagnosis推荐器 + * + *
+ * A brief introduction to Personality Diagnosis
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class PersonalityDiagnosisModel extends AbstractModel { + /** + * Gaussian noise: 2.5 suggested in the paper + */ + private float sigma; + + /** + * prior probability + */ + private float prior; + + private FloatList scores; + + /** + * initialization + * + * @throws ModelException if error occurs + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + prior = 1F / userSize; + sigma = configuration.getFloat("recommender.PersonalityDiagnosis.sigma"); + + FloatSet sorts = new FloatRBTreeSet(); + for (MatrixScalar term : scoreMatrix) { + sorts.add(term.getValue()); + } + sorts.remove(0F); + scores = new FloatArrayList(sorts); + } + + @Override + protected void doPractice() { + } + + /** + * predict a specific rating for user userIdx on item itemIdx. + * + * @param userIndex user index + * @param itemIndex item index + * @return predictive rating for user userIdx on item itemIdx + * @throws ModelException if error occurs + */ + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + int scoreSize = scores.size(); + float[] probabilities = new float[scoreSize]; + SparseVector itemVector = scoreMatrix.getColumnVector(itemIndex); + SparseVector rightUserVector = scoreMatrix.getRowVector(userIndex); + for (VectorScalar term : itemVector) { + // other users who rated item j + userIndex = term.getIndex(); + float score = term.getValue(); + float probability = 1F; + SparseVector leftUserVector = scoreMatrix.getRowVector(userIndex); + int leftCursor = 0, rightCursor = 0, leftSize = leftUserVector.getElementSize(), rightSize = rightUserVector.getElementSize(); + if (leftSize != 0 && rightSize != 0) { + Iterator leftIterator = leftUserVector.iterator(); + Iterator rightIterator = rightUserVector.iterator(); + VectorScalar leftTerm = leftIterator.next(); + VectorScalar rightTerm = rightIterator.next(); + // 判断两个有序数组中是否存在相同的数字 + while (leftCursor < leftSize && rightCursor < rightSize) { + if (leftTerm.getIndex() == rightTerm.getIndex()) { + probability *= gaussian(rightTerm.getValue(), leftTerm.getValue(), sigma); + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + leftCursor++; + rightCursor++; + } else if (leftTerm.getIndex() > rightTerm.getIndex()) { + if (rightIterator.hasNext()) { + rightTerm = rightIterator.next(); + } + rightCursor++; + } else if (leftTerm.getIndex() < rightTerm.getIndex()) { + if (leftIterator.hasNext()) { + leftTerm = leftIterator.next(); + } + leftCursor++; + } + } + } + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + probabilities[scoreIndex] += gaussian(scores.getFloat(scoreIndex), score, sigma) * probability; + } + } + for (int scoreIndex = 0; scoreIndex < scoreSize; scoreIndex++) { + probabilities[scoreIndex] *= prior; + } + int valueIndex = 0; + float probability = Float.MIN_VALUE; + for (int scoreIndex = 0; scoreIndex < probabilities.length; scoreIndex++) { + if (probabilities[scoreIndex] > probability) { + probability = probabilities[scoreIndex]; + valueIndex = scoreIndex; + } + } + instance.setQuantityMark(scores.get(valueIndex)); + } + + /** + * 非标准高斯实现 + * + * @param value + * @param mean + * @param standardDeviation + * @return + */ + private static float gaussian(float value, float mean, float 
standardDeviation) { + value = value - mean; + value = value / standardDeviation; + return (float) (Math.exp(-0.5F * value * value)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/extend/rating/SlopeOneModel.java b/src/main/java/com/jstarcraft/rns/model/extend/rating/SlopeOneModel.java new file mode 100644 index 0000000..293dec5 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/extend/rating/SlopeOneModel.java @@ -0,0 +1,103 @@ +package com.jstarcraft.rns.model.extend.rating; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.AbstractModel; +import com.jstarcraft.rns.model.exception.ModelException; + +/** + * + * Slope One推荐器 + * + *
+ * Slope One Predictors for Online Rating-Based Collaborative Filtering
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class SlopeOneModel extends AbstractModel { + /** + * matrices for item-item differences with number of occurrences/cardinal + */ + private DenseMatrix deviationMatrix, cardinalMatrix; + + /** + * initialization + * + * @throws ModelException if error occurs + */ + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + deviationMatrix = DenseMatrix.valueOf(itemSize, itemSize); + cardinalMatrix = DenseMatrix.valueOf(itemSize, itemSize); + } + + /** + * train model + * + * @throws ModelException if error occurs + */ + @Override + protected void doPractice() { + // compute items' differences + for (int userIndex = 0; userIndex < userSize; userIndex++) { + SparseVector itemVector = scoreMatrix.getRowVector(userIndex); + for (VectorScalar leftTerm : itemVector) { + float leftScore = leftTerm.getValue(); + for (VectorScalar rightTerm : itemVector) { + if (leftTerm.getIndex() != rightTerm.getIndex()) { + float rightScore = rightTerm.getValue(); + deviationMatrix.shiftValue(leftTerm.getIndex(), rightTerm.getIndex(), leftScore - rightScore); + cardinalMatrix.shiftValue(leftTerm.getIndex(), rightTerm.getIndex(), 1); + } + } + } + } + + // normalize differences + deviationMatrix.iterateElement(MathCalculator.PARALLEL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + float cardinal = cardinalMatrix.getValue(row, column); + scalar.setValue(cardinal > 0F ? value / cardinal : value); + }); + } + + /** + * predict a specific rating for user userIdx on item itemIdx. + * + * @param userIndex user index + * @param itemIndex item index + * @return predictive rating for user userIdx on item itemIdx + * @throws ModelException if error occurs + */ + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + float value = 0F, sum = 0F; + for (VectorScalar term : userVector) { + if (itemIndex == term.getIndex()) { + continue; + } + double cardinal = cardinalMatrix.getValue(itemIndex, term.getIndex()); + if (cardinal > 0F) { + value += (deviationMatrix.getValue(itemIndex, term.getIndex()) + term.getValue()) * cardinal; + sum += cardinal; + } + } + instance.setQuantityMark(sum > 0F ? value / sum : meanScore); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/model.txt b/src/main/java/com/jstarcraft/rns/model/model.txt new file mode 100644 index 0000000..7c2d05e --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/model.txt @@ -0,0 +1,46 @@ +One takeaway: across the two broad classes of prediction tasks, +predicting Ranking helps click-through rate, while predicting Rating helps satisfaction. +For example, people often click on (highly ranked) articles yet find them devoid of substance after actually reading them (the rating). +Although many hold that Ranking matters more than Rating, I think that is entirely the perspective of commercial systems. + +For example: +For e-commerce and advertisers, click-through rate is the lifeline of revenue, so they lean toward Ranking. +For music and movie vendors, reputation and satisfaction matter more, so they lean toward Rating. + +That is why the early recommendation contest held by Netflix was based on Rating. + +Note: +Having no negative samples does not mean recommendation is impossible. +Neighborhood-based algorithms, such as item-based collaborative filtering (ItemCF), can recommend on data sets that contain only positive samples. +Their basic idea is to draw a circle slightly larger than the positive-sample set and recommend to the user videos similar to those already watched. +But without negative samples, machine learning algorithms basically cannot work, +because they essentially draw a decision surface between positive and negative samples; without negatives there is no decision surface. + +Data from which positive and negative samples can be obtained directly is called explicit (multi-class) data. +Data with only positive samples is called implicit (one-class) data. +For example: +with ratings (1,2,3,4,5) the like-dislike degree is known explicitly; +with only a record of which goods a user has browsed, it is not. + +Imbalanced data: +https://flyxu.github.io/2016/07/28/2016-7-28/ + +The idea of One Class Collaborative Filtering (OCCF) is to construct negative samples (see the sketch below). 
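+
+A toy Java sketch of constructing negatives for one-class data; the set of observed positives and the uniform sampling policy are assumptions for illustration (popularity-proportional sampling, as in the RankSGD-style models in this repository, is also common):
+
+// Draw an item the user has not interacted with and treat it as a negative sample.
+// Assumes the user has not interacted with every item, otherwise the loop never ends.
+static int sampleNegativeItem(java.util.Set<Integer> positiveItems, int itemSize, java.util.Random random) {
+    int itemIndex;
+    do {
+        itemIndex = random.nextInt(itemSize); // retry until the item is unobserved
+    } while (positiveItems.contains(itemIndex));
+    return itemIndex;
+}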
+One-class collaborative filtering: +https://www.52ml.net/10234.html + +Learning to rank in practice - the RankNet method: +Depending on the type of training data, learning-to-rank methods fall into three classes: a) pointwise; b) pairwise; c) listwise. +https://yq.aliyun.com/articles/18 + +An introductory summary of and ramble on Learning to Rank: +http://www.cnblogs.com/wentingtu/archive/2012/03/13/2393993.html + +A summary of Learning to Rank: +http://blog.csdn.net/nanjunxiao/article/details/8976195 + +Random sampling and random simulation: Gibbs Sampling: +http://blog.csdn.net/pipisorry/article/details/51373090 + +When recommender systems meet deep learning: +https://www.jianshu.com/c/e12d7195a9ff \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/neural network.txt b/src/main/java/com/jstarcraft/rns/model/neural network.txt new file mode 100644 index 0000000..abd913b --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neural network.txt @@ -0,0 +1,5 @@ +A brief discussion of bias in neural networks: +http://www.cnblogs.com/shuaishuaidefeizhu/p/6832541.html + +Deep Learning study notes series: +http://blog.csdn.net/zouxy09/article/details/8775360 \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/AutoRecLossFunction.java b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/AutoRecLossFunction.java new file mode 100644 index 0000000..3a6ed55 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/AutoRecLossFunction.java @@ -0,0 +1,88 @@ +package com.jstarcraft.rns.model.neuralnetwork; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.commons.lang3.builder.HashCodeBuilder; +import org.nd4j.linalg.api.ndarray.INDArray; + +import com.jstarcraft.ai.math.structure.matrix.MathMatrix; +import com.jstarcraft.ai.math.structure.matrix.Nd4jMatrix; +import com.jstarcraft.ai.model.neuralnetwork.loss.LossFunction; + +/** + * + * AutoRec loss function + * + *
+ * AutoRec: Autoencoders Meet Collaborative Filtering
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class AutoRecLossFunction implements LossFunction { + + private Nd4jMatrix maskData; + + public AutoRecLossFunction(Nd4jMatrix maskData) { + this.maskData = maskData; + } + + @Override + public float computeScore(MathMatrix tests, MathMatrix trains, MathMatrix masks) { + float score = 0F; + if (tests instanceof Nd4jMatrix && trains instanceof Nd4jMatrix && maskData instanceof Nd4jMatrix) { + INDArray testArray = Nd4jMatrix.class.cast(tests).getArray(); + INDArray trainArray = Nd4jMatrix.class.cast(trains).getArray(); + INDArray scoreArray = trainArray.sub(testArray); + INDArray maskArray = Nd4jMatrix.class.cast(maskData).getArray(); + scoreArray.muli(scoreArray); + scoreArray.muli(maskArray); + score = scoreArray.sumNumber().floatValue(); + } + return score; + } + + @Override + public void computeGradient(MathMatrix tests, MathMatrix trains, MathMatrix masks, MathMatrix gradients) { + if (tests instanceof Nd4jMatrix && trains instanceof Nd4jMatrix && maskData instanceof Nd4jMatrix) { + INDArray testArray = Nd4jMatrix.class.cast(tests).getArray(); + INDArray trainArray = Nd4jMatrix.class.cast(trains).getArray(); + INDArray gradientArray = Nd4jMatrix.class.cast(gradients).getArray(); + INDArray maskArray = Nd4jMatrix.class.cast(maskData).getArray(); + trainArray.sub(testArray, gradientArray); + gradientArray.muli(2F); + gradientArray.muli(maskArray); + } + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + if (object == null) { + return false; + } + if (getClass() != object.getClass()) { + return false; + } else { + AutoRecLossFunction that = (AutoRecLossFunction) object; + EqualsBuilder equal = new EqualsBuilder(); + equal.append(this.maskData, that.maskData); + return equal.isEquals(); + } + } + + @Override + public int hashCode() { + HashCodeBuilder hash = new HashCodeBuilder(); + hash.append(maskData); + return hash.toHashCode(); + } + + @Override + public String toString() { + return "AutoRecLossFunction(maskData=" + maskData + ")"; + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/AutoRecModel.java b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/AutoRecModel.java new file mode 100644 index 0000000..46f1af9 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/AutoRecModel.java @@ -0,0 +1,163 @@ +package com.jstarcraft.rns.model.neuralnetwork; + +import java.util.HashMap; +import java.util.Map; + +import org.nd4j.linalg.factory.Nd4j; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.MathCache; +import com.jstarcraft.ai.math.structure.Nd4jCache; +import com.jstarcraft.ai.math.structure.matrix.MathMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.Nd4jMatrix; +import com.jstarcraft.ai.model.neuralnetwork.Graph; +import com.jstarcraft.ai.model.neuralnetwork.GraphConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.activation.IdentityActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.SigmoidActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.layer.Layer; +import com.jstarcraft.ai.model.neuralnetwork.layer.ParameterConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.layer.WeightLayer; +import com.jstarcraft.ai.model.neuralnetwork.learn.NesterovLearner; +import 
com.jstarcraft.ai.model.neuralnetwork.normalization.IgnoreNormalizer; +import com.jstarcraft.ai.model.neuralnetwork.optimization.StochasticGradientOptimizer; +import com.jstarcraft.ai.model.neuralnetwork.parameter.XavierUniformParameterFactory; +import com.jstarcraft.ai.model.neuralnetwork.schedule.ConstantSchedule; +import com.jstarcraft.ai.model.neuralnetwork.vertex.LayerVertex; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.EpocheModel; + +/** + * + * AutoRec recommender + * + *
+ * AutoRec: Autoencoders Meet Collaborative Filtering
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class AutoRecModel extends EpocheModel { + + /** + * the dimension of input units + */ + protected int inputDimension; + + /** + * the dimension of hidden units + */ + protected int hiddenDimension; + + /** + * the activation function of the hidden layer in the neural network + */ + protected String hiddenActivation; + + /** + * the activation function of the output layer in the neural network + */ + protected String outputActivation; + + /** + * the learning rate of the optimization algorithm + */ + protected float learnRatio; + + /** + * the momentum of the optimization algorithm + */ + protected float momentum; + + /** + * the regularization coefficient of the weights in the neural network + */ + protected float weightRegularization; + + /** + * the data structure that stores the training data + */ + protected Nd4jMatrix inputData; + + /** + * the data structure that stores the predicted data + */ + protected Nd4jMatrix outputData; + + protected Graph network; + + /** + * the data structure that indicates which element in the user-item is non-zero + */ + private Nd4jMatrix maskData; + + protected int getInputDimension() { + return userSize; + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + inputDimension = getInputDimension(); + hiddenDimension = configuration.getInteger("recommender.hidden.dimension"); + hiddenActivation = configuration.getString("recommender.hidden.activation"); + outputActivation = configuration.getString("recommender.output.activation"); + learnRatio = configuration.getFloat("recommender.iterator.learnrate"); + momentum = configuration.getFloat("recommender.iterator.momentum"); + weightRegularization = configuration.getFloat("recommender.weight.regularization"); + + // transform the sparse matrix to INDArray + int[] matrixShape = new int[] { itemSize, userSize }; + inputData = new Nd4jMatrix(Nd4j.zeros(matrixShape)); + maskData = new Nd4jMatrix(Nd4j.zeros(matrixShape)); + outputData = new Nd4jMatrix(Nd4j.zeros(matrixShape)); + for (MatrixScalar term : scoreMatrix) { + if (term.getValue() > 0D) { + inputData.setValue(term.getColumn(), term.getRow(), term.getValue()); + maskData.setValue(term.getColumn(), term.getRow(), 1F); + } + } + } + + protected Graph getComputationGraph() { + GraphConfigurator configurator = new GraphConfigurator(); + Map configurators = new HashMap<>(); + Nd4j.getRandom().setSeed(6L); + ParameterConfigurator parameterConfigurator = new ParameterConfigurator(0F, weightRegularization, new XavierUniformParameterFactory()); + configurators.put(WeightLayer.WEIGHT_KEY, parameterConfigurator); + configurators.put(WeightLayer.BIAS_KEY, new ParameterConfigurator(0F, 0F)); + MathCache factory = new Nd4jCache(); + Layer cdaeLayer = new WeightLayer(inputDimension, hiddenDimension, factory, configurators, new SigmoidActivationFunction()); + Layer outputLayer = new WeightLayer(hiddenDimension, inputDimension, factory, configurators, new IdentityActivationFunction()); + + configurator.connect(new LayerVertex("input", factory, cdaeLayer, new NesterovLearner(new ConstantSchedule(learnRatio), new ConstantSchedule(momentum)), new IgnoreNormalizer())); + configurator.connect(new LayerVertex("output", factory, outputLayer, new NesterovLearner(new ConstantSchedule(learnRatio), new ConstantSchedule(momentum)), new IgnoreNormalizer()), "input"); + + Graph graph = new Graph(configurator, new StochasticGradientOptimizer(), new 
AutoRecLossFunction(maskData)); + return graph; + } + + @Override + protected void doPractice() { + Graph graph = getComputationGraph(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = graph.practice(1, new MathMatrix[] { inputData }, new MathMatrix[] { inputData }); + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + graph.predict(new MathMatrix[] { inputData }, new MathMatrix[] { outputData }); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(outputData.getValue(itemIndex, userIndex)); + } +} diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/CDAELayer.java b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/CDAELayer.java new file mode 100644 index 0000000..627d6c4 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/CDAELayer.java @@ -0,0 +1,182 @@ +package com.jstarcraft.rns.model.neuralnetwork; + +import java.util.Map; + +import org.apache.commons.math3.util.FastMath; +import org.nd4j.linalg.api.ndarray.INDArray; + +import com.jstarcraft.ai.math.structure.MathCache; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.MathMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.Nd4jMatrix; +import com.jstarcraft.ai.model.neuralnetwork.activation.ActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.layer.ParameterConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.layer.WeightLayer; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.core.utility.StringUtility; + +/** + * + * CDAE layer + * + *
+ * Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class CDAELayer extends WeightLayer { + + public final static String USER_KEY = "user"; + + private int numberOfUsers; + + public CDAELayer(int numberOfUsers, int numberOfInputs, int numberOfOutputs, MathCache factory, Map configurators, ActivationFunction function) { + super(numberOfInputs, numberOfOutputs, factory, configurators, function); + + this.numberOfUsers = numberOfUsers; + if (!this.configurators.containsKey(USER_KEY)) { + String message = StringUtility.format("参数{}配置缺失.", USER_KEY); + throw new IllegalArgumentException(message); + } + + MathMatrix userParameter = factory.makeMatrix(numberOfUsers, numberOfOutputs); + configurators.get(USER_KEY).getFactory().setValues(userParameter); + this.parameters.put(USER_KEY, userParameter); + MathMatrix userGradient = factory.makeMatrix(numberOfUsers, numberOfOutputs); + this.gradients.put(USER_KEY, userGradient); + } + + @Override + public float calculateL1Norm() { + float l1Sum = super.calculateL1Norm(); + + Float userRegularization = configurators.get(USER_KEY).getL1Regularization(); + MathMatrix userParameters = parameters.get(USER_KEY); + if (userRegularization != null && userParameters != null) { + if (userParameters instanceof Nd4jMatrix) { + INDArray array = Nd4jMatrix.class.cast(userParameters).getArray(); + float norm = array.norm1Number().floatValue(); + l1Sum += userRegularization * norm; + } else { + float norm = 0F; + for (MatrixScalar term : userParameters) { + norm += FastMath.abs(term.getValue()); + } + l1Sum += userRegularization * norm; + } + } + + return l1Sum; + } + + @Override + public float calculateL2Norm() { + float l2Sum = super.calculateL2Norm(); + + Float userRegularization = configurators.get(USER_KEY).getL2Regularization(); + MathMatrix userParameters = parameters.get(USER_KEY); + if (userRegularization != null && userParameters != null) { + if (userParameters instanceof Nd4jMatrix) { + INDArray array = Nd4jMatrix.class.cast(userParameters).getArray(); + float norm = array.norm2Number().floatValue(); + l2Sum += 0.5F * userRegularization * norm; + } else { + double norm = 0F; + for (MatrixScalar term : userParameters) { + norm += term.getValue() * term.getValue(); + } + l2Sum += 0.5F * userRegularization * norm; + } + } + + return l2Sum; + } + + @Override + public void doCache(MathCache factory, KeyValue samples) { + // 检查维度 + if (samples.getKey().getRowSize() != numberOfUsers) { + throw new IllegalArgumentException(); + } + + super.doCache(factory, samples); + } + + @Override + public void doForward() { + MathMatrix weightParameters = parameters.get(WEIGHT_KEY); + MathMatrix biasParameters = parameters.get(BIAS_KEY); + MathMatrix userParameters = parameters.get(USER_KEY); + + MathMatrix inputData = inputKeyValue.getKey(); + MathMatrix middleData = middleKeyValue.getKey(); + MathMatrix outputData = outputKeyValue.getKey(); + + middleData.dotProduct(inputData, false, weightParameters, false, MathCalculator.PARALLEL); + middleData.iterateElement(MathCalculator.PARALLEL, (scalar) -> { + int row = scalar.getRow(); + int column = scalar.getColumn(); + float value = scalar.getValue(); + scalar.setValue(value + userParameters.getValue(row, column)); + }); + if (biasParameters != null) { + for (int columnIndex = 0, columnSize = middleData.getColumnSize(); columnIndex < columnSize; columnIndex++) { + float bias = biasParameters.getValue(0, columnIndex); + middleData.getColumnVector(columnIndex).shiftValues(bias); + } + } + + function.forward(middleData, outputData); + + MathMatrix 
middleError = middleKeyValue.getValue(); + middleError.setValues(0F); + MathMatrix innerError = outputKeyValue.getValue(); + innerError.setValues(0F); + } + + @Override + public void doBackward() { + MathMatrix weightParameters = parameters.get(WEIGHT_KEY); + MathMatrix biasParameters = parameters.get(BIAS_KEY); + MathMatrix userParameters = parameters.get(USER_KEY); + MathMatrix weightGradients = gradients.get(WEIGHT_KEY); + MathMatrix biasGradients = gradients.get(BIAS_KEY); + MathMatrix userGradients = gradients.get(USER_KEY); + + MathMatrix inputData = inputKeyValue.getKey(); + MathMatrix middleData = middleKeyValue.getKey(); + MathMatrix outputData = outputKeyValue.getKey(); + + MathMatrix innerError = outputKeyValue.getValue(); + MathMatrix middleError = middleKeyValue.getValue(); + MathMatrix outerError = inputKeyValue.getValue(); + + // 计算梯度 + function.backward(middleData, innerError, middleError); + weightGradients.dotProduct(inputData, true, middleError, false, MathCalculator.PARALLEL); + userGradients.copyMatrix(middleError, false); + if (biasGradients != null) { + for (int columnIndex = 0, columnSize = biasGradients.getColumnSize(); columnIndex < columnSize; columnIndex++) { + float bias = middleError.getColumnVector(columnIndex).getSum(false); + biasGradients.setValue(0, columnIndex, bias); + } + } + + // weightParameters.doProduct(middleError.transpose()).transpose() + if (outerError != null) { + // TODO 使用累计的方式计算 + // TODO 需要锁机制,否则并发计算会导致Bug + outerError.accumulateProduct(middleError, false, weightParameters, true, MathCalculator.PARALLEL); + } + } + + @Override + public String toString() { + return "CDAELayer"; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/CDAEModel.java b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/CDAEModel.java new file mode 100644 index 0000000..a22c750 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/CDAEModel.java @@ -0,0 +1,179 @@ +package com.jstarcraft.rns.model.neuralnetwork; + +import java.util.HashMap; +import java.util.Map; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.math.structure.MathCache; +import com.jstarcraft.ai.math.structure.Nd4jCache; +import com.jstarcraft.ai.math.structure.matrix.MathMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.Nd4jMatrix; +import com.jstarcraft.ai.model.neuralnetwork.Graph; +import com.jstarcraft.ai.model.neuralnetwork.GraphConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.activation.IdentityActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.SigmoidActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.layer.Layer; +import com.jstarcraft.ai.model.neuralnetwork.layer.ParameterConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.layer.WeightLayer; +import com.jstarcraft.ai.model.neuralnetwork.learn.NesterovLearner; +import com.jstarcraft.ai.model.neuralnetwork.loss.MSELossFunction; +import com.jstarcraft.ai.model.neuralnetwork.normalization.IgnoreNormalizer; +import com.jstarcraft.ai.model.neuralnetwork.optimization.StochasticGradientOptimizer; +import com.jstarcraft.ai.model.neuralnetwork.parameter.XavierParameterFactory; +import com.jstarcraft.ai.model.neuralnetwork.schedule.ConstantSchedule; +import 
com.jstarcraft.ai.model.neuralnetwork.vertex.LayerVertex; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.model.EpocheModel; + +/** + * + * CDAE recommender + * + *
+ * Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
+ * Refer to the LibRec team<br>
+ * 
+ * + * @author Birdy + * + */ +public class CDAEModel extends EpocheModel { + + /** + * the dimension of input units + */ + protected int inputDimension; + + /** + * the dimension of hidden units + */ + protected int hiddenDimension; + + /** + * the activation function of the hidden layer in the neural network + */ + protected String hiddenActivation; + + /** + * the activation function of the output layer in the neural network + */ + protected String outputActivation; + + /** + * the learning rate of the optimization algorithm + */ + protected float learnRatio; + + /** + * the momentum of the optimization algorithm + */ + protected float momentum; + + /** + * the regularization coefficient of the weights in the neural network + */ + protected float weightRegularization; + + /** + * the data structure that stores the training data + */ + protected Nd4jMatrix inputData; + + protected Nd4jMatrix labelData; + + /** + * the data structure that stores the predicted data + */ + protected Nd4jMatrix outputData; + + protected Graph network; + + /** + * the threshold to binarize the rating + */ + private float binarie; + + protected int getInputDimension() { + return itemSize; + } + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + inputDimension = getInputDimension(); + hiddenDimension = configuration.getInteger("recommender.hidden.dimension"); + hiddenActivation = configuration.getString("recommender.hidden.activation"); + outputActivation = configuration.getString("recommender.output.activation"); + learnRatio = configuration.getFloat("recommender.iterator.learnrate"); + momentum = configuration.getFloat("recommender.iterator.momentum"); + weightRegularization = configuration.getFloat("recommender.weight.regularization"); + binarie = configuration.getFloat("recommender.binarize.threshold"); + // transform the sparse matrix to INDArray + // the sparse training matrix has been binarized + + INDArray array = Nd4j.create(userSize, itemSize); + inputData = new Nd4jMatrix(array); + + array = Nd4j.create(userSize, itemSize); + labelData = new Nd4jMatrix(array); + for (MatrixScalar term : scoreMatrix) { + labelData.setValue(term.getRow(), term.getColumn(), 1F); + } + + array = Nd4j.create(userSize, itemSize); + outputData = new Nd4jMatrix(array); + } + + protected Graph getComputationGraph() { + GraphConfigurator configurator = new GraphConfigurator(); + Map configurators = new HashMap<>(); + Nd4j.getRandom().setSeed(6L); + ParameterConfigurator parameterConfigurator = new ParameterConfigurator(0F, weightRegularization, new XavierParameterFactory()); + configurators.put(CDAELayer.WEIGHT_KEY, parameterConfigurator); + configurators.put(CDAELayer.BIAS_KEY, new ParameterConfigurator(0F, 0F)); + configurators.put(CDAELayer.USER_KEY, parameterConfigurator); + MathCache factory = new Nd4jCache(); + Layer cdaeLayer = new CDAELayer(userSize, itemSize, hiddenDimension, factory, configurators, new SigmoidActivationFunction()); + Layer outputLayer = new WeightLayer(hiddenDimension, itemSize, factory, configurators, new IdentityActivationFunction()); + + configurator.connect(new LayerVertex("cdae", factory, cdaeLayer, new NesterovLearner(new ConstantSchedule(momentum), new ConstantSchedule(learnRatio)), new IgnoreNormalizer())); + configurator.connect(new LayerVertex("output", factory, outputLayer, new NesterovLearner(new ConstantSchedule(momentum), new ConstantSchedule(learnRatio)), new IgnoreNormalizer()), "cdae"); + + 
Graph graph = new Graph(configurator, new StochasticGradientOptimizer(), new MSELossFunction()); + return graph; + } + + @Override + protected void doPractice() { + Graph graph = getComputationGraph(); + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + inputData.getArray().assign(labelData.getArray()); + for (MatrixScalar term : scoreMatrix) { + if (RandomUtility.randomFloat(1F) < 0.2F) { + inputData.setValue(term.getRow(), term.getColumn(), 0F); + } + } + totalError = graph.practice(1, new MathMatrix[] { inputData }, new MathMatrix[] { labelData }); + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + graph.predict(new MathMatrix[] { labelData }, new MathMatrix[] { outputData }); + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + instance.setQuantityMark(outputData.getValue(userIndex, itemIndex)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/DeepCrossModel.java b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/DeepCrossModel.java new file mode 100644 index 0000000..f2a7f07 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/DeepCrossModel.java @@ -0,0 +1,346 @@ +package com.jstarcraft.rns.model.neuralnetwork; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +import org.nd4j.linalg.factory.Nd4j; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.processor.DataSorter; +import com.jstarcraft.ai.data.processor.DataSplitter; +import com.jstarcraft.ai.math.structure.DenseCache; +import com.jstarcraft.ai.math.structure.MathCache; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.model.neuralnetwork.Graph; +import com.jstarcraft.ai.model.neuralnetwork.GraphConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.activation.IdentityActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.ReLUActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.SigmoidActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.layer.EmbedLayer; +import com.jstarcraft.ai.model.neuralnetwork.layer.Layer; +import com.jstarcraft.ai.model.neuralnetwork.layer.ParameterConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.layer.WeightLayer; +import com.jstarcraft.ai.model.neuralnetwork.learn.SgdLearner; +import com.jstarcraft.ai.model.neuralnetwork.loss.BinaryXENTLossFunction; +import com.jstarcraft.ai.model.neuralnetwork.normalization.IgnoreNormalizer; +import com.jstarcraft.ai.model.neuralnetwork.optimization.StochasticGradientOptimizer; +import com.jstarcraft.ai.model.neuralnetwork.parameter.NormalParameterFactory; +import com.jstarcraft.ai.model.neuralnetwork.schedule.ConstantSchedule; +import com.jstarcraft.ai.model.neuralnetwork.schedule.Schedule; +import com.jstarcraft.ai.model.neuralnetwork.vertex.LayerVertex; +import com.jstarcraft.ai.model.neuralnetwork.vertex.ShareVertex; +import com.jstarcraft.ai.model.neuralnetwork.vertex.accumulation.OuterProductVertex; +import com.jstarcraft.ai.model.neuralnetwork.vertex.operation.PlusVertex; +import 
com.jstarcraft.ai.model.neuralnetwork.vertex.operation.ShiftVertex; +import com.jstarcraft.ai.model.neuralnetwork.vertex.transformation.HorizontalAttachVertex; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.data.processor.AllFeatureDataSorter; +import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter; +import com.jstarcraft.rns.model.EpocheModel; + +/** + * DCN recommender + * + *
+ * DCN: Deep & Cross Network for Ad Click Prediction<br>
+ * 
+ */ +public class DeepCrossModel extends EpocheModel { + /** + * the learning rate of the optimization algorithm + */ + protected float learnRatio; + + /** + * the momentum of the optimization algorithm + */ + protected float momentum; + + /** + * the regularization coefficient of the weights in the neural network + */ + protected float weightRegularization; + + /** + * 所有维度的特征总数 + */ + private int numberOfFeatures; + + /** + * the data structure that stores the training data N个样本 f个filed + */ + protected DenseMatrix[] inputData; + + /** + * the data structure that stores the predicted data + */ + protected DenseMatrix outputData; + + /** + * 计算图 + */ + protected Graph graph; + + protected int[] dimensionSizes; + + protected DataModule marker; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + learnRatio = configuration.getFloat("recommender.iterator.learnrate"); + momentum = configuration.getFloat("recommender.iterator.momentum"); + weightRegularization = configuration.getFloat("recommender.weight.regularization"); + this.marker = model; + + // TODO 此处需要重构,外部索引与内部索引的映射转换 + dimensionSizes = new int[marker.getQualityOrder()]; + for (int orderIndex = 0, orderSize = marker.getQualityOrder(); orderIndex < orderSize; orderIndex++) { + Entry> term = marker.getOuterKeyValue(orderIndex); + dimensionSizes[marker.getQualityInner(term.getValue().getKey())] = space.getQualityAttribute(term.getValue().getKey()).getSize(); + } + } + + protected Graph getComputationGraph(int[] dimensionSizes) { + Schedule schedule = new ConstantSchedule(learnRatio); + GraphConfigurator configurator = new GraphConfigurator(); + Map configurators = new HashMap<>(); + Nd4j.getRandom().setSeed(6L); + ParameterConfigurator parameter = new ParameterConfigurator(weightRegularization, 0F, new NormalParameterFactory()); + configurators.put(WeightLayer.WEIGHT_KEY, parameter); + configurators.put(WeightLayer.BIAS_KEY, new ParameterConfigurator(0F, 0F)); + MathCache factory = new DenseCache(); + + // 构建Embed节点 + + int numberOfFactors = 10; + String[] embedVertexNames = new String[dimensionSizes.length]; + for (int fieldIndex = 0; fieldIndex < dimensionSizes.length; fieldIndex++) { + embedVertexNames[fieldIndex] = "Embed" + fieldIndex; + Layer embedLayer = new EmbedLayer(dimensionSizes[fieldIndex], numberOfFactors, factory, configurators, new IdentityActivationFunction()); + configurator.connect(new LayerVertex(embedVertexNames[fieldIndex], factory, embedLayer, new SgdLearner(schedule), new IgnoreNormalizer())); + } + + // 构建Net Input节点 + int numberOfHiddens = 20; + configurator.connect(new HorizontalAttachVertex("EmbedStack", factory), embedVertexNames); + configurator.connect(new ShiftVertex("EmbedStack0", factory, 0F), "EmbedStack"); + Layer netLayer = new WeightLayer(dimensionSizes.length * numberOfFactors, numberOfHiddens, factory, configurators, new ReLUActivationFunction()); + configurator.connect(new LayerVertex("NetInput", factory, netLayer, new SgdLearner(schedule), new IgnoreNormalizer()), "EmbedStack"); + + // cross net + // 构建crossNet + + int numberOfCrossLayers = 3; + + for (int crossLayerIndex = 0; crossLayerIndex < numberOfCrossLayers; crossLayerIndex++) { + if (crossLayerIndex == 0) { + configurator.connect(new OuterProductVertex("OuterProduct" + crossLayerIndex, factory), "EmbedStack0", "EmbedStack"); // (n,fk*fk) + } else { + configurator.connect(new OuterProductVertex("OuterProduct" + crossLayerIndex, factory), 
"EmbedStack" + crossLayerIndex, "EmbedStack"); // (n,fk*fk) + } + + // // 水平切割 + // String[] outerProductShare=new String[dimensionSizes.length * + // numberOfFactors]; + // for(int shareIndex=0;shareIndex= userVector.getIndex(position)) { + negativeItemIndex++; + continue; + } + break; + } + // TODO 注意,此处为了故意制造负面特征. + int negativePosition = RandomUtility.randomInteger(module.getSize()); + instance.setCursor(negativePosition); + for (int index = 0; index < negativeKeys.length; index++) { + negativeKeys[index] = instance.getQualityFeature(index); + } + negativeKeys[itemDimension] = negativeItemIndex; + + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + // inputData[dimension].putScalar(batchIndex, 0, + // positiveKeys[dimension]); + // inputData[dimensionSizes.length].setValue(batchIndex, dimension, + // positiveKeys[dimension]); + inputData[dimension].setValue(batchIndex, 0, positiveKeys[dimension]); + } + labelData.setValue(batchIndex, 0, 1); + batchIndex++; + + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + // inputData[dimension].putScalar(batchIndex, 0, + // negativeKeys[dimension]); + // inputData[dimensionSizes.length].setValue(batchIndex, dimension, + // negativeKeys[dimension]); + inputData[dimension].setValue(batchIndex, 0, negativeKeys[dimension]); + } + labelData.setValue(batchIndex, 0, 0); + batchIndex++; + } + totalError = graph.practice(100, inputData, new DenseMatrix[] { labelData }); + + DenseMatrix[] data = new DenseMatrix[inputData.length]; + DenseMatrix label = DenseMatrix.valueOf(10, 1); + for (int index = 0; index < data.length; index++) { + DenseMatrix input = inputData[index]; + data[index] = DenseMatrix.valueOf(10, input.getColumnSize()); + data[index].iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(input.getValue(scalar.getRow(), scalar.getColumn())); + }); + } + graph.predict(data, new DenseMatrix[] { label }); + System.out.println(label); + + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + + // inputData[dimensionSizes.length] = DenseMatrix.valueOf(numberOfUsers, + // dimensionSizes.length); + for (int index = 0; index < dimensionSizes.length; index++) { + inputData[index] = DenseMatrix.valueOf(userSize, 1); + } + + for (int userIndex = 0; userIndex < userSize; userIndex++) { + DataModule model = models[userIndex]; + if (model.getSize() > 0) { + instance = model.getInstance(model.getSize() - 1); + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + if (dimension != itemDimension) { + int feature = instance.getQualityFeature(dimension); + // inputData[dimension].putScalar(userIndex, 0, + // keys[dimension]); + // inputData[dimensionSizes.length].setValue(userIndex, dimension, feature); + inputData[dimension].setValue(userIndex, 0, feature); + } + } + } else { + // inputData[dimensionSizes.length].setValue(userIndex, userDimension, + // userIndex); + inputData[userDimension].setValue(userIndex, 0, userIndex); + } + } + + DenseMatrix labelData = DenseMatrix.valueOf(userSize, 1); + outputData = DenseMatrix.valueOf(userSize, itemSize); + + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + // inputData[dimensionSizes.length].getColumnVector(itemDimension).calculate(VectorMapper.constantOf(itemIndex), + // null, Calculator.SERIAL); + inputData[itemDimension].setValues(itemIndex); + graph.predict(inputData, new DenseMatrix[] { labelData }); + 
outputData.getColumnVector(itemIndex).iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(labelData.getValue(scalar.getIndex(), 0)); + }); + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = outputData.getValue(userIndex, itemIndex); + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/DeepFMModel.java b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/DeepFMModel.java new file mode 100644 index 0000000..e327527 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/DeepFMModel.java @@ -0,0 +1,336 @@ +package com.jstarcraft.rns.model.neuralnetwork; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +import org.nd4j.linalg.factory.Nd4j; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.processor.DataSorter; +import com.jstarcraft.ai.data.processor.DataSplitter; +import com.jstarcraft.ai.math.structure.DenseCache; +import com.jstarcraft.ai.math.structure.MathCache; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.vector.SparseVector; +import com.jstarcraft.ai.model.neuralnetwork.Graph; +import com.jstarcraft.ai.model.neuralnetwork.GraphConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.activation.IdentityActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.ReLUActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.activation.SigmoidActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.layer.EmbedLayer; +import com.jstarcraft.ai.model.neuralnetwork.layer.Layer; +import com.jstarcraft.ai.model.neuralnetwork.layer.ParameterConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.layer.WeightLayer; +import com.jstarcraft.ai.model.neuralnetwork.learn.SgdLearner; +import com.jstarcraft.ai.model.neuralnetwork.loss.BinaryXENTLossFunction; +import com.jstarcraft.ai.model.neuralnetwork.normalization.IgnoreNormalizer; +import com.jstarcraft.ai.model.neuralnetwork.optimization.StochasticGradientOptimizer; +import com.jstarcraft.ai.model.neuralnetwork.parameter.NormalParameterFactory; +import com.jstarcraft.ai.model.neuralnetwork.schedule.ConstantSchedule; +import com.jstarcraft.ai.model.neuralnetwork.schedule.Schedule; +import com.jstarcraft.ai.model.neuralnetwork.vertex.LayerVertex; +import com.jstarcraft.ai.model.neuralnetwork.vertex.accumulation.InnerProductVertex; +import com.jstarcraft.ai.model.neuralnetwork.vertex.transformation.HorizontalAttachVertex; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.rns.data.processor.AllFeatureDataSorter; +import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter; +import com.jstarcraft.rns.model.EpocheModel; + +/** + * + * DeepFM推荐器 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
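+ * Combined prediction from the paper: y = sigmoid(y_FM + y_DNN)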
+ * 
+ * + * @author Birdy + * + */ +public class DeepFMModel extends EpocheModel { + + /** + * the learning rate of the optimization algorithm + */ + protected float learnRatio; + + /** + * the momentum of the optimization algorithm + */ + protected float momentum; + + /** + * the regularization coefficient of the weights in the neural network + */ + protected float weightRegularization; + + /** + * 所有维度的特征总数 + */ + private int numberOfFeatures; + + /** + * the data structure that stores the training data + */ + protected DenseMatrix[] inputData; + + /** + * the data structure that stores the predicted data + */ + protected DenseMatrix outputData; + + /** + * 计算图 + */ + protected Graph graph; + + protected int[] dimensionSizes; + + protected DataModule marker; + + @Override + public void prepare(Option configuration, DataModule model, DataSpace space) { + super.prepare(configuration, model, space); + learnRatio = configuration.getFloat("recommender.iterator.learnrate"); + momentum = configuration.getFloat("recommender.iterator.momentum"); + weightRegularization = configuration.getFloat("recommender.weight.regularization"); + this.marker = model; + + // TODO 此处需要重构,外部索引与内部索引的映射转换 + dimensionSizes = new int[model.getQualityOrder()]; + for (int orderIndex = 0, orderSize = model.getQualityOrder(); orderIndex < orderSize; orderIndex++) { + Entry> term = model.getOuterKeyValue(orderIndex); + dimensionSizes[model.getQualityInner(term.getValue().getKey())] = space.getQualityAttribute(term.getValue().getKey()).getSize(); + } + } + + protected Graph getComputationGraph(int[] dimensionSizes) { + Schedule schedule = new ConstantSchedule(learnRatio); + GraphConfigurator configurator = new GraphConfigurator(); + Map configurators = new HashMap<>(); + Nd4j.getRandom().setSeed(6L); + ParameterConfigurator parameter = new ParameterConfigurator(weightRegularization, 0F, new NormalParameterFactory()); + configurators.put(WeightLayer.WEIGHT_KEY, parameter); + configurators.put(WeightLayer.BIAS_KEY, new ParameterConfigurator(0F, 0F)); + MathCache factory = new DenseCache(); + + // 构建Embed节点 + // TODO 应该调整为配置项. + int numberOfFactors = 10; + // TODO Embed只支持输入的column为1. + String[] embedVertexNames = new String[dimensionSizes.length]; + for (int fieldIndex = 0; fieldIndex < dimensionSizes.length; fieldIndex++) { + embedVertexNames[fieldIndex] = "Embed" + fieldIndex; + Layer embedLayer = new EmbedLayer(dimensionSizes[fieldIndex], numberOfFactors, factory, configurators, new IdentityActivationFunction()); + configurator.connect(new LayerVertex(embedVertexNames[fieldIndex], factory, embedLayer, new SgdLearner(schedule), new IgnoreNormalizer())); + } + + // 构建因子分解机部分 + // 构建FM Plus节点(实际就是FM的输入) + numberOfFeatures = 0; + for (int fieldIndex = 0; fieldIndex < dimensionSizes.length; fieldIndex++) { + numberOfFeatures += dimensionSizes[fieldIndex]; + } + // TODO 注意,由于EmbedLayer不支持与其它Layer共享输入,所以FM Plus节点构建自己的One Hot输入. 
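+        // Sketch of the contract assumed here (inferred from FMLayer below): the
+        // layer receives one raw category index per field and resolves it against
+        // a single weight table with numberOfFeatures rows, i.e. the concatenation
+        // of every field's one-hot space, so the first-order FM term is computed
+        // without materializing the one-hot vectors.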
+ Layer fmLayer = new FMLayer(dimensionSizes, numberOfFeatures, 1, factory, configurators, new IdentityActivationFunction()); + configurator.connect(new LayerVertex("FMPlus", factory, fmLayer, new SgdLearner(schedule), new IgnoreNormalizer())); + + // 构建FM Product节点 + // 注意:节点数量是(n*(n-1)/2)),n为Embed节点数量 + String[] productVertexNames = new String[dimensionSizes.length * (dimensionSizes.length - 1) / 2]; + int productIndex = 0; + for (int outterFieldIndex = 0; outterFieldIndex < dimensionSizes.length; outterFieldIndex++) { + for (int innerFieldIndex = outterFieldIndex + 1; innerFieldIndex < dimensionSizes.length; innerFieldIndex++) { + productVertexNames[productIndex] = "FMProduct" + outterFieldIndex + ":" + innerFieldIndex; + String left = embedVertexNames[outterFieldIndex]; + String right = embedVertexNames[innerFieldIndex]; + configurator.connect(new InnerProductVertex(productVertexNames[productIndex], factory), left, right); + productIndex++; + } + } + + // 构建FM Sum节点(实际就是FM的输出) + String[] names = new String[productVertexNames.length + 2]; + System.arraycopy(productVertexNames, 0, names, 0, productVertexNames.length); + names[productVertexNames.length] = "FMPlus"; + // configurator.connect(new SumVertex("FMOutput"), names); + + // 构建多层网络部分 + // 构建Net Input节点 + // TODO 调整为支持输入(连续域)Dense Field. + // TODO 应该调整为配置项. + int numberOfHiddens = 20; + configurator.connect(new HorizontalAttachVertex("EmbedStack", factory), embedVertexNames); + Layer netLayer = new WeightLayer(dimensionSizes.length * numberOfFactors, numberOfHiddens, factory, configurators, new ReLUActivationFunction()); + configurator.connect(new LayerVertex("NetInput", factory, netLayer, new SgdLearner(schedule), new IgnoreNormalizer()), "EmbedStack"); + + // TODO 应该调整为配置项. + int numberOfLayers = 5; + String currentLayer = "NetInput"; + for (int layerIndex = 0; layerIndex < numberOfLayers; layerIndex++) { + Layer hiddenLayer = new WeightLayer(numberOfHiddens, numberOfHiddens, factory, configurators, new ReLUActivationFunction()); + configurator.connect(new LayerVertex("NetHidden" + layerIndex, factory, hiddenLayer, new SgdLearner(schedule), new IgnoreNormalizer()), currentLayer); + currentLayer = "NetHidden" + layerIndex; + } + names[productVertexNames.length + 1] = currentLayer; + + // 构建Deep Output节点 + configurator.connect(new HorizontalAttachVertex("DeepStack", factory), names); + Layer deepLayer = new WeightLayer(productVertexNames.length + 1 + numberOfHiddens, 1, factory, configurators, new SigmoidActivationFunction()); + configurator.connect(new LayerVertex("DeepOutput", factory, deepLayer, new SgdLearner(schedule), new IgnoreNormalizer()), "DeepStack"); + + Graph graph = new Graph(configurator, new StochasticGradientOptimizer(), new BinaryXENTLossFunction(false)); + return graph; + } + + @Override + protected void doPractice() { + DataSplitter splitter = new QualityFeatureDataSplitter(userDimension); + DataModule[] models = splitter.split(marker, userSize); + DataSorter sorter = new AllFeatureDataSorter(); + for (int index = 0; index < userSize; index++) { + models[index] = sorter.sort(models[index]); + } + + DataInstance instance; + + int[] positiveKeys = new int[dimensionSizes.length], negativeKeys = new int[dimensionSizes.length]; + + graph = getComputationGraph(dimensionSizes); + + for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) { + totalError = 0F; + + // TODO 应该调整为配置项. 
+ int batchSize = 2000; + inputData = new DenseMatrix[dimensionSizes.length + 1]; + inputData[dimensionSizes.length] = DenseMatrix.valueOf(batchSize, dimensionSizes.length); + for (int index = 0; index < dimensionSizes.length; index++) { + inputData[index] = DenseMatrix.valueOf(batchSize, 1); + } + DenseMatrix labelData = DenseMatrix.valueOf(batchSize, 1); + + for (int batchIndex = 0; batchIndex < batchSize;) { + // 随机用户 + int userIndex = RandomUtility.randomInteger(userSize); + SparseVector userVector = scoreMatrix.getRowVector(userIndex); + if (userVector.getElementSize() == 0 || userVector.getElementSize() == itemSize) { + continue; + } + + DataModule module = models[userIndex]; + instance = module.getInstance(0); + // 获取正样本 + int positivePosition = RandomUtility.randomInteger(module.getSize()); + instance.setCursor(positivePosition); + for (int index = 0; index < positiveKeys.length; index++) { + positiveKeys[index] = instance.getQualityFeature(index); + } + + // 获取负样本 + int negativeItemIndex = RandomUtility.randomInteger(itemSize - userVector.getElementSize()); + for (int position = 0, size = userVector.getElementSize(); position < size; position++) { + if (negativeItemIndex >= userVector.getIndex(position)) { + negativeItemIndex++; + continue; + } + break; + } + // TODO 注意,此处为了故意制造负面特征. + int negativePosition = RandomUtility.randomInteger(module.getSize()); + instance.setCursor(negativePosition); + for (int index = 0; index < negativeKeys.length; index++) { + negativeKeys[index] = instance.getQualityFeature(index); + } + negativeKeys[itemDimension] = negativeItemIndex; + + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + // inputData[dimension].putScalar(batchIndex, 0, + // positiveKeys[dimension]); + inputData[dimensionSizes.length].setValue(batchIndex, dimension, positiveKeys[dimension]); + inputData[dimension].setValue(batchIndex, 0, positiveKeys[dimension]); + } + labelData.setValue(batchIndex, 0, 1); + batchIndex++; + + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + // inputData[dimension].putScalar(batchIndex, 0, + // negativeKeys[dimension]); + inputData[dimensionSizes.length].setValue(batchIndex, dimension, negativeKeys[dimension]); + inputData[dimension].setValue(batchIndex, 0, negativeKeys[dimension]); + } + labelData.setValue(batchIndex, 0, 0); + batchIndex++; + } + totalError = graph.practice(100, inputData, new DenseMatrix[] { labelData }); + + DenseMatrix[] data = new DenseMatrix[inputData.length]; + DenseMatrix label = DenseMatrix.valueOf(10, 1); + for (int index = 0; index < data.length; index++) { + DenseMatrix input = inputData[index]; + data[index] = DenseMatrix.valueOf(10, input.getColumnSize()); + data[index].iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(input.getValue(scalar.getRow(), scalar.getColumn())); + }); + } + graph.predict(data, new DenseMatrix[] { label }); + System.out.println(label); + + if (isConverged(epocheIndex) && isConverged) { + break; + } + currentError = totalError; + } + + inputData[dimensionSizes.length] = DenseMatrix.valueOf(userSize, dimensionSizes.length); + for (int index = 0; index < dimensionSizes.length; index++) { + inputData[index] = DenseMatrix.valueOf(userSize, 1); + } + + for (int userIndex = 0; userIndex < userSize; userIndex++) { + DataModule model = models[userIndex]; + if (model.getSize() > 0) { + instance = model.getInstance(model.getSize() - 1); + for (int dimension = 0; dimension < dimensionSizes.length; dimension++) { + if (dimension != 
itemDimension) { + int feature = instance.getQualityFeature(dimension); + // inputData[dimension].putScalar(userIndex, 0, + // keys[dimension]); + inputData[dimensionSizes.length].setValue(userIndex, dimension, feature); + inputData[dimension].setValue(userIndex, 0, feature); + } + } + } else { + inputData[dimensionSizes.length].setValue(userIndex, userDimension, userIndex); + inputData[userDimension].setValue(userIndex, 0, userIndex); + } + } + + DenseMatrix labelData = DenseMatrix.valueOf(userSize, 1); + outputData = DenseMatrix.valueOf(userSize, itemSize); + + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + inputData[dimensionSizes.length].getColumnVector(itemDimension).setValues(itemIndex); + inputData[itemDimension].setValues(itemIndex); + graph.predict(inputData, new DenseMatrix[] { labelData }); + outputData.getColumnVector(itemIndex).iterateElement(MathCalculator.SERIAL, (scalar) -> { + scalar.setValue(labelData.getValue(scalar.getIndex(), 0)); + }); + } + } + + @Override + public void predict(DataInstance instance) { + int userIndex = instance.getQualityFeature(userDimension); + int itemIndex = instance.getQualityFeature(itemDimension); + float value = outputData.getValue(userIndex, itemIndex); + instance.setQuantityMark(value); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/FMLayer.java b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/FMLayer.java new file mode 100644 index 0000000..0cd6e53 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/FMLayer.java @@ -0,0 +1,137 @@ +package com.jstarcraft.rns.model.neuralnetwork; + +import java.util.Map; + +import com.jstarcraft.ai.math.structure.MathCache; +import com.jstarcraft.ai.math.structure.matrix.MathMatrix; +import com.jstarcraft.ai.model.neuralnetwork.activation.ActivationFunction; +import com.jstarcraft.ai.model.neuralnetwork.layer.ParameterConfigurator; +import com.jstarcraft.ai.model.neuralnetwork.layer.WeightLayer; +import com.jstarcraft.core.utility.KeyValue; + +/** + * + * FM层 + * + *
+ * DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
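+ * Within this patch the layer realizes the first-order (linear) FM term over
+ * one-hot fields: y = b + sum_i w[offset_i + x_i], where offset_i is the
+ * cumulative size of the preceding fields; the pairwise term is assembled
+ * from InnerProductVertex nodes in DeepFMModel.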
+ * 
+ * + * @author Birdy + * + */ +public class FMLayer extends WeightLayer { + + private int[] dimensionSizes; + + protected FMLayer() { + super(); + } + + public FMLayer(int[] dimensionSizes, int numberOfInputs, int numberOfOutputs, MathCache factory, Map configurators, ActivationFunction function) { + super(numberOfInputs, numberOfOutputs, factory, configurators, function); + this.dimensionSizes = dimensionSizes; + } + + @Override + public void doCache(MathCache factory, KeyValue samples) { + inputKeyValue = samples; + int rowSize = inputKeyValue.getKey().getRowSize(); + int columnSize = inputKeyValue.getKey().getColumnSize(); + + // 检查维度 + if (this.dimensionSizes.length != columnSize) { + throw new IllegalArgumentException(); + } + + middleKeyValue = new KeyValue<>(null, null); + outputKeyValue = new KeyValue<>(null, null); + + MathMatrix middleData = factory.makeMatrix(rowSize, numberOfOutputs); + middleKeyValue.setKey(middleData); + MathMatrix middleError = factory.makeMatrix(rowSize, numberOfOutputs); + middleKeyValue.setValue(middleError); + + MathMatrix outputData = factory.makeMatrix(rowSize, numberOfOutputs); + outputKeyValue.setKey(outputData); + MathMatrix innerError = factory.makeMatrix(rowSize, numberOfOutputs); + outputKeyValue.setValue(innerError); + } + + @Override + public void doForward() { + MathMatrix weightParameters = parameters.get(WEIGHT_KEY); + MathMatrix biasParameters = parameters.get(BIAS_KEY); + + MathMatrix inputData = inputKeyValue.getKey(); + MathMatrix middleData = middleKeyValue.getKey(); + MathMatrix outputData = outputKeyValue.getKey(); + + // inputData.dotProduct(weightParameters, middleData); + for (int row = 0; row < inputData.getRowSize(); row++) { + for (int column = 0; column < weightParameters.getColumnSize(); column++) { + float value = 0F; + int cursor = 0; + for (int index = 0; index < inputData.getColumnSize(); index++) { + value += weightParameters.getValue(cursor + (int) inputData.getValue(row, index), column); + cursor += dimensionSizes[index]; + } + middleData.setValue(row, column, value); + } + } + if (biasParameters != null) { + for (int columnIndex = 0, columnSize = middleData.getColumnSize(); columnIndex < columnSize; columnIndex++) { + float bias = biasParameters.getValue(0, columnIndex); + middleData.getColumnVector(columnIndex).shiftValues(bias); + } + } + + function.forward(middleData, outputData); + + MathMatrix middleError = middleKeyValue.getValue(); + middleError.setValues(0F); + + MathMatrix innerError = outputKeyValue.getValue(); + innerError.setValues(0F); + } + + @Override + public void doBackward() { + MathMatrix weightParameters = parameters.get(WEIGHT_KEY); + MathMatrix biasParameters = parameters.get(BIAS_KEY); + MathMatrix weightGradients = gradients.get(WEIGHT_KEY); + MathMatrix biasGradients = gradients.get(BIAS_KEY); + + MathMatrix innerError = outputKeyValue.getValue(); + MathMatrix middleError = middleKeyValue.getValue(); + // 必须为null + MathMatrix outerError = inputKeyValue.getValue(); + + MathMatrix inputData = inputKeyValue.getKey(); + MathMatrix middleData = middleKeyValue.getKey(); + MathMatrix outputData = outputKeyValue.getKey(); + + // 计算梯度 + function.backward(middleData, innerError, middleError); + + // inputData.transposeProductThat(middleError, weightGradients); + weightGradients.setValues(0F); + for (int index = 0; index < inputData.getRowSize(); index++) { + for (int column = 0; column < middleError.getColumnSize(); column++) { + int cursor = 0; + for (int dimension = 0; dimension < 
dimensionSizes.length; dimension++) { + int point = cursor + (int) inputData.getValue(index, dimension); + weightGradients.shiftValue(point, column, middleError.getValue(index, column)); + cursor += dimensionSizes[dimension]; + } + } + } + if (biasGradients != null) { + for (int columnIndex = 0, columnSize = biasGradients.getColumnSize(); columnIndex < columnSize; columnIndex++) { + float bias = middleError.getColumnVector(columnIndex).getSum(false); + biasGradients.setValue(0, columnIndex, bias); + } + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/model/neuralnetwork/neural.txt b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/neural.txt new file mode 100644 index 0000000..5d10481 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/neuralnetwork/neural.txt @@ -0,0 +1,11 @@ +YouTube深度神经网络推荐系统: +https://www.jianshu.com/p/c5b8268d273b + +4篇YouTube推荐系统论文, 一起来看看别人家的孩子: +https://medium.com/@yaoyaowd/4%E7%AF%87youtube%E6%8E%A8%E8%8D%90%E7%B3%BB%E7%BB%9F%E8%AE%BA%E6%96%87-%E4%B8%80%E8%B5%B7%E6%9D%A5%E7%9C%8B%E7%9C%8B%E5%88%AB%E4%BA%BA%E5%AE%B6%E7%9A%84%E5%AD%A9%E5%AD%90-b91279e03f83 + +Youtube 短视频推荐系统变迁:从机器学习到深度学习: +https://cloud.tencent.com/developer/article/1005439 + +《Deep Neural Networks for YouTube Recommendations》学习笔记: +https://blog.csdn.net/a819825294/article/details/71215538 \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/model/updateable.txt b/src/main/java/com/jstarcraft/rns/model/updateable.txt new file mode 100644 index 0000000..6bb5745 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/model/updateable.txt @@ -0,0 +1,2 @@ +sigir16-eals: +https://github.com/hexiangnan/sigir16-eals diff --git a/src/main/java/com/jstarcraft/rns/svd.txt b/src/main/java/com/jstarcraft/rns/svd.txt new file mode 100644 index 0000000..9815cd6 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/svd.txt @@ -0,0 +1,2 @@ +奇异值分解(SVD) --- 几何意义: +http://blog.sciencenet.cn/blog-696950-699432.html \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/utility/Digamma.java b/src/main/java/com/jstarcraft/rns/utility/Digamma.java new file mode 100644 index 0000000..f226b9f --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/utility/Digamma.java @@ -0,0 +1,99 @@ +package com.jstarcraft.rns.utility; + +/** + * Digamma + * + *
+ * http://www.psc.edu/~burkardt/src/dirichlet/dirichlet.f
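+ * digamma(x) = d/dx ln(Gamma(x))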
+ * 
+ * + * @author Birdy + * + */ +class Digamma { + + static float large = 9.5F; + static float d1 = -0.5772156649015328606065121F; // digamma(1) + static float d2 = (float) (Math.pow(Math.PI, 2.0) / 6.0F); + static float small = 1E-6F; + static float s3 = 1F / 12F; + static float s4 = 1F / 120F; + static float s5 = 1F / 252F; + static float s6 = 1F / 240F; + static float s7 = 1F / 132F; + static float s8 = 691F / 32760F; + static float s9 = 1F / 12F; + static float s10 = 3617F / 8160F; + + public static float calculate(float x) { + float y = 0F; + float r = 0F; + + if (Float.isInfinite(x) || Float.isNaN(x)) { + return 0F / 0F; + } + + if (x == 0F) { + return -1F / 0F; + } + + if (x < 0F) { + // Use the reflection formula (Jeffrey 11.1.6): + // digamma(-x) = digamma(x+1) + pi*cot(pi*x) + y = (float) (Digamma.calculate(-x + 1) + Math.PI * (1F / Math.tan(-Math.PI * x))); + return y; + // This is related to the identity + // digamma(-x) = digamma(x+1) - digamma(z) + digamma(1-z) + // where z is the fractional part of x + // For example: + // digamma(-3.1) = 1/3.1 + 1/2.1 + 1/1.1 + 1/0.1 + digamma(1-0.1) + // = digamma(4.1) - digamma(0.1) + digamma(1-0.1) + // Then we use + // digamma(1-z) - digamma(z) = pi*cot(pi*z) + } + + // Use approximation if argument <= small. + if (x <= small) { + y = y + d1 - 1F / x + d2 * x; + return y; + } + + // Reduce to digamma(X + N) where (X + N) >= large. + while (true) { + if (x > small && x < large) { + y = y - 1F / x; + x = x + 1F; + } else { + break; + } + } + + // Use de Moivre's expansion if argument >= large. + // In maple: asympt(Psi(x), x); + if (x >= large) { + r = 1F / x; + y = (float) (y + Math.log(x) - 0.5F * r); + r = r * r; + y = y - r * (s3 - r * (s4 - r * (s5 - r * (s6 - r * s7)))); + } + + return y; + } + + // return the inverse function of digamma + // i.e., returns x such that digamma(x) = y + // adapted from Tony Minka fastfit Matlab code + public static float inverse(float y, int n) { + // Newton iteration to solve digamma(x)-y = 0 + float x = (float) (Math.exp(y) + 0.5F); + if (y <= -2.22F) { + x = -1F / (y - calculate(1)); + } + + for (int iter = 0; iter < n; iter++) { + x = x - (calculate(x) - y) / Trigamma.calculate(x); + } + return x; + } + +} \ No newline at end of file diff --git a/src/main/java/com/jstarcraft/rns/utility/GammaUtility.java b/src/main/java/com/jstarcraft/rns/utility/GammaUtility.java new file mode 100644 index 0000000..222c6b3 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/utility/GammaUtility.java @@ -0,0 +1,60 @@ +package com.jstarcraft.rns.utility; + +/** + * 伽玛工具 + * + * @author Birdy + * + */ +public class GammaUtility { + + /** + * 伽玛函数 + * + * @param value + * @return + */ + public static float gamma(float value) { + return (float) Math.exp(logGamma(value)); + } + + /** + * 伽玛函数的对数 + * + * @param value + * @return + */ + public static float logGamma(float value) { + if (value <= 0F) { + return Float.NaN; + } + float tmp = (float) ((value - 0.5F) * Math.log(value + 4.5F) - (value + 4.5F)); + float ser = 1F + 76.18009173F / (value + 0F) - 86.50532033F / (value + 1F) + 24.01409822F / (value + 2F) - 1.231739516F / (value + 3F) + 0.00120858003F / (value + 4F) - 0.00000536382F / (value + 5F); + return (float) (tmp + Math.log(ser * Math.sqrt(2F * Math.PI))); + } + + /** + * 伽玛函数的对数的一阶导数 + * + * @param value + * @return + */ + public static float digamma(float value) { + return Digamma.calculate(value); + } + + /** + * 伽玛函数的对数的二阶导数 + * + * @param value + * @return + */ + public static float trigamma(float value) 
{ + return Trigamma.calculate(value); + } + + public static float inverse(float y, int n) { + return Digamma.inverse(y, n); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/utility/GaussianUtility.java b/src/main/java/com/jstarcraft/rns/utility/GaussianUtility.java new file mode 100644 index 0000000..6c3bdad --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/utility/GaussianUtility.java @@ -0,0 +1,77 @@ +package com.jstarcraft.rns.utility; + +import com.jstarcraft.ai.math.MathUtility; + +/** + * 高斯工具 + * + *
+ * https://en.wikipedia.org/wiki/Normal_distribution
+ * http://www.cnblogs.com/mengfanrong/p/4369545.html
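+ * probabilityDensity(x) = exp(-x * x / 2) / sqrt(2 * pi); the inverse
+ * distribution is solved by bisection over [-8, 8] standard deviations.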
+ * 
+ * + * @author Birdy + * + */ +public class GaussianUtility { + + public static float probabilityDensity(float value) { + return (float) (Math.exp(-0.5F * value * value) / Math.sqrt(2F * Math.PI)); + } + + public static float probabilityDensity(float value, float mean, float standardDeviation) { + return probabilityDensity((value - mean) / standardDeviation) / standardDeviation; + } + + public static float cumulativeDistribution(float value) { + if (value < -8F) { + return 0F; + } + if (value > 8F) { + return 1F; + } + float sum = 0F, term = value; + for (int index = 3; sum + term != sum; index += 2) { + sum = sum + term; + term = term * value * value / index; + } + return 0.5F + sum * probabilityDensity(value); + } + + public static float cumulativeDistribution(float value, float mean, float standardDeviation) { + return cumulativeDistribution((value - mean) / standardDeviation); + } + + public static float inverseDistribution(float value) { + return inverseDistribution(value, MathUtility.EPSILON, -8F, 8F); + } + + private static float inverseDistribution(float value, float delta, float minimum, float maximum) { + float median = minimum + (maximum - minimum) / 2F; + if (maximum - minimum < delta) { + return median; + } + if (cumulativeDistribution(median) > value) { + return inverseDistribution(value, delta, minimum, median); + } else { + return inverseDistribution(value, delta, median, maximum); + } + } + + public static float inverseDistribution(float value, float mean, float standardDeviation) { + return inverseDistribution(value, mean, standardDeviation, MathUtility.EPSILON, (mean - 8F * standardDeviation), (mean + 8F * standardDeviation)); + } + + private static float inverseDistribution(float value, float mean, float standardDeviation, float delta, float minimum, float maximum) { + float median = minimum + (maximum - minimum) / 2F; + if (maximum - minimum < delta) { + return median; + } + if (cumulativeDistribution(median, mean, standardDeviation) > value) { + return inverseDistribution(value, mean, standardDeviation, delta, minimum, median); + } else { + return inverseDistribution(value, mean, standardDeviation, delta, median, maximum); + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/utility/LogisticUtility.java b/src/main/java/com/jstarcraft/rns/utility/LogisticUtility.java new file mode 100644 index 0000000..235f766 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/utility/LogisticUtility.java @@ -0,0 +1,35 @@ +package com.jstarcraft.rns.utility; + +/** + * 似然工具 + * + * @author Birdy + * + */ +public class LogisticUtility { + + /** + * 获取似然值 + * + *
+     * logistic(x) + logistic(-x) = 1
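+     * logistic(x) = 1 / (1 + exp(-x))
+     * logistic'(x) = logistic(x) * logistic(-x)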
+     * 
+ * + * @param value + * @return + */ + public static float getValue(float value) { + return (float) (1F / (1F + Math.exp(-value))); + } + + /** + * 获取似然梯度 + * + * @param value + * @return + */ + public static float getGradient(float value) { + return (float) (getValue(value) * getValue(-value)); + } + +} diff --git a/src/main/java/com/jstarcraft/rns/utility/SampleUtility.java b/src/main/java/com/jstarcraft/rns/utility/SampleUtility.java new file mode 100644 index 0000000..320d6f7 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/utility/SampleUtility.java @@ -0,0 +1,69 @@ +package com.jstarcraft.rns.utility; + +import com.jstarcraft.ai.math.structure.vector.MathVector; + +/** + * 采样工具 + * + * @author Birdy + * + */ +public class SampleUtility { + + /** + * 二分查找 + * + * @param values + * @param low + * @param high + * @param random + * @return + */ + public static int binarySearch(float[] values, int low, int high, float random) { + while (low < high) { + if (random < values[low]) { + return low; + } + if (random >= values[high - 1]) { + return high; + } + int middle = (low + high) / 2; + if (random < values[middle]) { + high = middle; + } else { + low = middle; + } + } + // throw new RecommendationException("概率范围超过随机范围,检查是否由于多线程修改导致."); + return -1; + } + + /** + * 二分查找 + * + * @param vector + * @param low + * @param high + * @param random + * @return + */ + public static int binarySearch(MathVector vector, int low, int high, float random) { + while (low < high) { + if (random < vector.getValue(low)) { + return low; + } + if (random >= vector.getValue(high - 1)) { + return high; + } + int middle = (low + high) / 2; + if (random < vector.getValue(middle)) { + high = middle; + } else { + low = middle; + } + } + // throw new RecommendationException("概率范围超过随机范围,检查是否由于多线程修改导致."); + return -1; + } + +} diff --git a/src/main/java/com/jstarcraft/rns/utility/SearchUtility.java b/src/main/java/com/jstarcraft/rns/utility/SearchUtility.java new file mode 100644 index 0000000..bbb73c9 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/utility/SearchUtility.java @@ -0,0 +1,192 @@ +package com.jstarcraft.rns.utility; + +import java.util.Iterator; +import java.util.concurrent.Semaphore; + +import com.jstarcraft.ai.environment.EnvironmentContext; +import com.jstarcraft.ai.math.structure.DefaultScalar; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.MathScalar; +import com.jstarcraft.ai.math.structure.matrix.MathMatrix; +import com.jstarcraft.ai.math.structure.vector.MathVector; +import com.jstarcraft.ai.math.structure.vector.VectorScalar; + +/** + * 搜索工具 + * + * @author Birdy + * + */ +public class SearchUtility { + + /** 阻尼系数 */ + private final static float defaultAlpha = 0.8F; + + /** 收敛系数 */ + private final static float defaultEpsilon = 0.001F; + + public static float[] pageRank(MathCalculator mode, int dimension, MathMatrix matrix) { + return pageRank(mode, dimension, matrix, defaultAlpha, defaultEpsilon); + } + + public static float[] pageRank(MathCalculator mode, int dimension, MathMatrix matrix, float alpha, float epsilon) { + // 随机性调整 + float stochasticity = 1F / dimension; + // 原始性调整 + float primitivity = (1F - alpha) * stochasticity; + + // 悬孤 + // TODO 考虑重构为int[],节省存储空间 + boolean[] ganglers = new boolean[dimension]; + for (int rowIndex = 0; rowIndex < dimension; rowIndex++) { + MathVector vector = matrix.getRowVector(rowIndex); + if (vector.getElementSize() == 0 || vector.getSum(false) == 0F) { + ganglers[rowIndex] = true; + } else { + 
vector.scaleValues(alpha); + vector.shiftValues(primitivity); + } + } + + switch (mode) { + case SERIAL: { + // 得分 + float[] scores = new float[dimension]; + for (int index = 0; index < dimension; index++) { + scores[index] = stochasticity; + } + // 判断是否收敛 + float error = 1F; + while (error >= epsilon) { + error = 0F; + for (int columnIndex = 0; columnIndex < dimension; columnIndex++) { + float score = 0F; + Iterator iterator = matrix.getColumnVector(columnIndex).iterator(); + VectorScalar scalar = null; + int index = -1; + float value = 0F; + if (iterator.hasNext()) { + scalar = iterator.next(); + index = scalar.getIndex(); + value = scalar.getValue(); + } + for (int rowIndex = 0; rowIndex < dimension; rowIndex++) { + if (index == rowIndex) { + // 判断是否为悬孤 + if (ganglers[rowIndex]) { + score += scores[rowIndex] * stochasticity; + } else { + score += scores[rowIndex] * value; + } + if (iterator.hasNext()) { + scalar = iterator.next(); + index = scalar.getIndex(); + value = scalar.getValue(); + } else { + scalar = null; + index = -1; + value = 0F; + } + } else { + // 判断是否为悬孤 + if (ganglers[rowIndex]) { + score += scores[rowIndex] * stochasticity; + } else { + score += scores[rowIndex] * primitivity; + } + } + } + error += Math.abs(score - scores[columnIndex]); + scores[columnIndex] = score; + } + } + return scores; + } + default: { + float[] scores = null; + // 得分 + float[] rowScores = new float[dimension]; + for (int index = 0; index < dimension; index++) { + rowScores[index] = stochasticity; + } + float[] columnScores = new float[dimension]; + for (int index = 0; index < dimension; index++) { + columnScores[index] = stochasticity; + } + // 判断是否收敛 + EnvironmentContext context = EnvironmentContext.getContext(); + Semaphore semaphore = MathCalculator.getSemaphore(); + MathScalar outerError = DefaultScalar.getInstance(); + outerError.setValue(1F); + while (outerError.getValue() >= epsilon) { + outerError.setValue(0F); + context.doAlgorithmByEvery(() -> { + MathScalar innerError = DefaultScalar.getInstance(); + innerError.setValue(0F); + }); + for (int columnIndex = 0; columnIndex < dimension; columnIndex++) { + int column = columnIndex; + float[] rowReference = rowScores; + float[] columnReference = columnScores; + context.doAlgorithmByAny(columnIndex, () -> { + float score = 0F; + Iterator iterator = matrix.getColumnVector(column).iterator(); + VectorScalar scalar = null; + int index = -1; + float value = 0F; + if (iterator.hasNext()) { + scalar = iterator.next(); + index = scalar.getIndex(); + value = scalar.getValue(); + } + for (int rowIndex = 0; rowIndex < dimension; rowIndex++) { + if (index == rowIndex) { + // 判断是否为悬孤 + if (ganglers[rowIndex]) { + score += rowReference[rowIndex] * stochasticity; + } else { + score += rowReference[rowIndex] * value; + } + if (iterator.hasNext()) { + scalar = iterator.next(); + index = scalar.getIndex(); + value = scalar.getValue(); + } else { + scalar = null; + index = -1; + value = 0F; + } + } else { + // 判断是否为悬孤 + if (ganglers[rowIndex]) { + score += rowReference[rowIndex] * stochasticity; + } else { + score += rowReference[rowIndex] * primitivity; + } + } + } + MathScalar innerError = DefaultScalar.getInstance(); + innerError.shiftValue(Math.abs(score - columnReference[column])); + columnReference[column] = score; + semaphore.release(); + }); + } + scores = columnScores; + columnScores = rowScores; + rowScores = scores; + try { + semaphore.acquire(dimension); + } catch (Exception exception) { + throw new RuntimeException(exception); + } + 
context.doAlgorithmByEvery(() -> { + MathScalar innerError = DefaultScalar.getInstance(); + outerError.shiftValue(innerError.getValue()); + }); + } + return scores; + } + } + } + +} diff --git a/src/main/java/com/jstarcraft/rns/utility/Trigamma.java b/src/main/java/com/jstarcraft/rns/utility/Trigamma.java new file mode 100644 index 0000000..61fc3d4 --- /dev/null +++ b/src/main/java/com/jstarcraft/rns/utility/Trigamma.java @@ -0,0 +1,70 @@ +package com.jstarcraft.rns.utility; + +/** + * Trigamma + * + *
+ * http://www.psc.edu/~burkardt/src/dirichlet/dirichlet.f
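+ * trigamma(x) = d^2/dx^2 ln(Gamma(x))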
+ * 
+ * + * @author Birdy + * + */ +class Trigamma { + + static float small = 1E-4F; + static float large = 8F; + static float c = (float) (Math.pow(Math.PI, 2F) / 6F); + static float c1 = -2.404113806319188570799476F; + static float b2 = 1F / 6F; + static float b4 = -1F / 30F; + static float b6 = 1F / 42F; + static float b8 = -1F / 30F; + static float b10 = 5F / 66F; + + static float calculate(float x) { + float y = 0F; + float z = 0F; + + if (Float.isInfinite(x) || Float.isNaN(x)) { + return 0F / 0F; + } + + // zero or negative integer + if (x <= 0F && Math.floor(x) == x) { + return 1F / 0F; + } + + // Negative non-integer + if (x < 0 && Math.floor(x) != x) { + // Use the derivative of the digamma reflection formula: + // -trigamma(-x) = trigamma(x+1) - (pi*csc(pi*x))^2 + y = (float) (-Trigamma.calculate(-x + 1F) + Math.pow(Math.PI * (1F / Math.sin(-Math.PI * x)), 2F)); + return y; + } + + // Small value approximation + if (x <= small) { + y = 1F / (x * x) + c + c1 * x; + return y; + } + + // Reduce to trigamma(x+n) where ( X + N ) >= large. + while (true) { + if (x > small && x < large) { + y = y + 1F / (x * x); + x = x + 1F; + } else { + break; + } + } + + if (x >= large) { + z = 1F / (x * x); + y = y + 0.5F * z + (1F + z * (b2 + z * (b4 + z * (b6 + z * (b8 + z * b10))))) / x; + } + + return y; + } + +} \ No newline at end of file diff --git a/src/test/java/com/jstarcraft/rns/MockDataFactory.java b/src/test/java/com/jstarcraft/rns/MockDataFactory.java new file mode 100644 index 0000000..25eefa7 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/MockDataFactory.java @@ -0,0 +1,147 @@ +package com.jstarcraft.rns; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; + +import org.apache.commons.io.FileUtils; +import org.junit.Test; + +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.core.utility.StringUtility; + +/** + * 模拟数据工厂 + * + * @author Birdy + * + */ +public class MockDataFactory { + + private int userSize = 1000; + + private int itemSize = 1000; + + private int profileSize = 1000; + + private float profileScope = 1F; + + private float scoreScope = 5F; + + private int instantSize = 1000; + + private int locationSize = 180; + + private int commentSize = 1000; + + private float commentScope = 1F; + + private float ratio = 0.01F; + + /** + * user(离散:1:稠密)-profile(连续:n:稀疏) + * + *
+     * Can also be treated as user(discrete:1:dense)-user(discrete:1:dense)-degree(continuous:1:dense)
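+     * Example row (hypothetical values): "7 0.52  0.97   0.13"
+     * where an empty column stands for a profile dimension without a degree.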
+     * 
+ * + * > + */ + @Test + public void mockUserProfile() throws Exception { + File file = new File("data/mock/user-profile"); + FileUtils.deleteQuietly(file); + file.getParentFile().mkdirs(); + file.createNewFile(); + StringBuilder buffer = new StringBuilder(); + try (FileWriter writer = new FileWriter(file); BufferedWriter out = new BufferedWriter(writer);) { + for (int leftIndex = 0; leftIndex < userSize; leftIndex++) { + buffer.setLength(0); + for (int rightIndex = 0; rightIndex < profileSize; rightIndex++) { + if (RandomUtility.randomFloat(1F) < ratio) { + float degree = RandomUtility.randomFloat(profileScope); + buffer.append(degree); + } + buffer.append(" "); + } + String profile = buffer.substring(0, buffer.length() - 1); + out.write(StringUtility.format("{} {}", leftIndex, profile)); + out.newLine(); + } + } + } + + /** + * item(离散:1:稠密)-profile(连续:n:稀疏) + * + *
+     * Can also be treated as item(discrete:1:dense)-item(discrete:1:dense)-degree(continuous:1:dense)
+     * 
+ */ + @Test + public void mockItemProfile() throws Exception { + File file = new File("data/mock/item-profile"); + FileUtils.deleteQuietly(file); + file.getParentFile().mkdirs(); + file.createNewFile(); + StringBuilder buffer = new StringBuilder(); + try (FileWriter writer = new FileWriter(file); BufferedWriter out = new BufferedWriter(writer);) { + for (int leftIndex = 0; leftIndex < itemSize; leftIndex++) { + buffer.setLength(0); + for (int rightIndex = 0; rightIndex < profileSize; rightIndex++) { + if (RandomUtility.randomFloat(1F) < ratio) { + float degree = RandomUtility.randomFloat(profileScope); + buffer.append(degree); + } + buffer.append(" "); + } + String profile = buffer.substring(0, buffer.length() - 1); + out.write(StringUtility.format("{} {}", leftIndex, profile)); + out.newLine(); + } + } + } + + /** + * user(离散:1:稠密)-item(离散:1:稠密)-score(连续:1:稠密)-instant(离散:1:稠密)-location(离散:2:稠密)-comment(连续:n:稀疏) + */ + @Test + public void mockUserItemScoreInstantLocationComment() throws Exception { + File file = new File("data/mock/user-item-score-instant-location-comment"); + FileUtils.deleteQuietly(file); + file.getParentFile().mkdirs(); + file.createNewFile(); + StringBuilder buffer = new StringBuilder(); + try (FileWriter writer = new FileWriter(file); BufferedWriter out = new BufferedWriter(writer);) { + for (int leftIndex = 0; leftIndex < userSize; leftIndex++) { + for (int rightIndex = 0; rightIndex < itemSize; rightIndex++) { + // 此处故意选择特定的数据(TODO 考虑改为利用正态分布) + if (rightIndex < 10 || RandomUtility.randomFloat(1F) < ratio) { + // 得分 + float score = RandomUtility.randomFloat(scoreScope); + // 时间 + int instant = RandomUtility.randomInteger(instantSize); + // 地点(经度) + int longitude = RandomUtility.randomInteger(locationSize); + // 地点(纬度) + int latitude = RandomUtility.randomInteger(locationSize); + buffer.setLength(0); + for (int commentIndex = 0; commentIndex < commentSize; commentIndex++) { + if (RandomUtility.randomFloat(1F) < ratio) { + float degree = RandomUtility.randomFloat(commentScope); + buffer.append(degree); + } + buffer.append(" "); + } + // 评论 + String comment = buffer.substring(0, buffer.length() - 1); + out.write(StringUtility.format("{} {} {} {} {} {} {}", leftIndex, rightIndex, score, instant, longitude, latitude, comment)); + out.newLine(); + } + + } + } + } + } + +} diff --git a/src/test/java/com/jstarcraft/rns/configure/ConfigurationTestCase.java b/src/test/java/com/jstarcraft/rns/configure/ConfigurationTestCase.java new file mode 100644 index 0000000..4f8abbe --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/configure/ConfigurationTestCase.java @@ -0,0 +1,30 @@ +package com.jstarcraft.rns.configure; + +import static org.junit.Assert.assertEquals; + +import java.lang.reflect.Type; +import java.util.LinkedList; + +import org.junit.Test; + +import com.jstarcraft.core.common.conversion.json.JsonUtility; +import com.jstarcraft.core.common.reflection.TypeUtility; +import com.jstarcraft.core.utility.KeyValue; + +public class ConfigurationTestCase { + + @Test + public void testConfigurationJson() { + LinkedList>> left = new LinkedList<>(); + left.add(new KeyValue<>("user", String.class)); + left.add(new KeyValue<>("item", int.class)); + left.add(new KeyValue<>("score", double.class)); + left.add(new KeyValue<>("word", String.class)); + String json = JsonUtility.object2String(left); + Type type = TypeUtility.parameterize(KeyValue.class, String.class, Class.class); + type = TypeUtility.parameterize(LinkedList.class, type); + LinkedList>> right = 
JsonUtility.string2Object(json, type); + assertEquals(left, right); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangAttributeHandler.java b/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangAttributeHandler.java new file mode 100644 index 0000000..f1b132e --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangAttributeHandler.java @@ -0,0 +1,88 @@ +package com.jstarcraft.rns.data.converter; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.StringReader; + +import org.xml.sax.Attributes; +import org.xml.sax.SAXException; +import org.xml.sax.helpers.DefaultHandler; + +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.attribute.QualityAttribute; +import com.jstarcraft.ai.data.attribute.QuantityAttribute; +import com.jstarcraft.ai.data.exception.DataException; + +public class YongfengZhangAttributeHandler extends DefaultHandler { + + private DataSpace space; + + private QualityAttribute userAttribute; + + private QualityAttribute itemAttribute; + + private QualityAttribute wordAttribute; + + private QuantityAttribute scoreAttribute; + + private QuantityAttribute sentimentAttribute; + + private StringBuffer buffer = new StringBuffer(); + + YongfengZhangAttributeHandler(DataSpace space) { + this.space = space; + this.userAttribute = space.getQualityAttribute("user"); + this.itemAttribute = space.getQualityAttribute("item"); + this.wordAttribute = space.getQualityAttribute("word"); + this.scoreAttribute = space.getQuantityAttribute("score"); + this.sentimentAttribute = space.getQuantityAttribute("sentiment"); + } + + private void parseData(BufferedReader buffer) throws IOException { + String line = buffer.readLine(); + line = buffer.readLine(); + String[] strings = line.split("\t"); + String user = strings[0]; + userAttribute.convertData(user); + String item = strings[1]; + itemAttribute.convertData(item); + Float score = Float.valueOf(strings[3]); + scoreAttribute.convertData(score); + + line = buffer.readLine(); + line = buffer.readLine(); + strings = line.split("\t"); + for (String string : strings) { + string = string.substring(1, string.length() - 1); + String[] elements = string.split(", "); + String word = elements[0]; + wordAttribute.convertData(word); + Float sentiment = elements[4].equalsIgnoreCase("N") ? Float.valueOf(elements[2]) : -Float.valueOf(elements[2]); + sentiment *= Float.valueOf(elements[3]); + sentimentAttribute.convertData(sentiment); + } + } + + @Override + public void startElement(String uri, String localName, String name, Attributes attributes) throws SAXException { + buffer.setLength(0); + } + + @Override + public void endElement(String uri, String localName, String name) throws SAXException { + try { + try (StringReader reader = new StringReader(buffer.toString()); BufferedReader stream = new BufferedReader(reader)) { + parseData(stream); + } + } catch (Exception exception) { + // TODO 处理日志. 
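+            // Wrapping in DataException preserves the failure cause until the
+            // TODO above adds proper logging of the offending record.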
+ throw new DataException(exception); + } + } + + @Override + public void characters(char characters[], int index, int length) throws SAXException { + buffer.append(characters, index, length); + } + +} \ No newline at end of file diff --git a/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangDatasetConverter.java b/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangDatasetConverter.java new file mode 100644 index 0000000..aec0c11 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangDatasetConverter.java @@ -0,0 +1,152 @@ +package com.jstarcraft.rns.data.converter; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringReader; +import java.util.Collection; +import java.util.Map.Entry; + +import javax.xml.parsers.SAXParser; +import javax.xml.parsers.SAXParserFactory; + +import org.xml.sax.Attributes; +import org.xml.sax.InputSource; +import org.xml.sax.SAXException; +import org.xml.sax.XMLReader; +import org.xml.sax.helpers.DefaultHandler; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.attribute.QualityAttribute; +import com.jstarcraft.ai.data.attribute.QuantityAttribute; +import com.jstarcraft.ai.data.converter.AbstractConverter; +import com.jstarcraft.ai.data.exception.DataException; +import com.jstarcraft.core.common.conversion.ConversionUtility; +import com.jstarcraft.core.utility.KeyValue; + +import it.unimi.dsi.fastutil.ints.Int2FloatRBTreeMap; +import it.unimi.dsi.fastutil.ints.Int2FloatSortedMap; +import it.unimi.dsi.fastutil.ints.Int2IntRBTreeMap; +import it.unimi.dsi.fastutil.ints.Int2IntSortedMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; + +public class YongfengZhangDatasetConverter extends AbstractConverter { + + class YongfengZhangInstanceHandler extends DefaultHandler { + + private DataModule module; + + private Int2IntSortedMap qualityFeatures = new Int2IntRBTreeMap(); + + private Int2FloatSortedMap quantityFeatures = new Int2FloatRBTreeMap(); + + private Int2ObjectOpenHashMap datas = new Int2ObjectOpenHashMap<>(); + + private StringBuffer buffer = new StringBuffer(); + + private int count; + + YongfengZhangInstanceHandler(DataModule module) { + this.module = module; + } + + private void parseData(BufferedReader buffer) throws IOException { + datas.clear(); + String line = buffer.readLine(); + line = buffer.readLine(); + String[] strings = line.split("\t"); + String user = strings[0]; + datas.put(0, user); + String item = strings[1]; + datas.put(1, item); + Float score = Float.valueOf(strings[3]); + datas.put(2, score); + line = buffer.readLine(); + line = buffer.readLine(); + strings = line.split("\t"); + for (String string : strings) { + string = string.substring(1, string.length() - 1); + String[] elements = string.split(", "); + String word = elements[0]; + int wordIndex = wordAttribute.convertData(word); + Float sentiment = elements[4].equalsIgnoreCase("N") ? 
Float.valueOf(elements[2]) : -Float.valueOf(elements[2]); + sentiment *= Float.valueOf(elements[3]); + datas.put(3 + wordIndex, sentiment); + } + + for (Int2ObjectMap.Entry data : datas.int2ObjectEntrySet()) { + int index = data.getIntKey(); + Object value = data.getValue(); + Entry> term = module.getOuterKeyValue(index); + KeyValue keyValue = term.getValue(); + if (keyValue.getValue()) { + QualityAttribute attribute = qualityAttributes.get(keyValue.getKey()); + value = ConversionUtility.convert(value, attribute.getType()); + int feature = attribute.convertData((Comparable) value); + qualityFeatures.put(module.getQualityInner(keyValue.getKey()) + index - term.getKey(), feature); + } else { + QuantityAttribute attribute = quantityAttributes.get(keyValue.getKey()); + value = ConversionUtility.convert(value, attribute.getType()); + float feature = attribute.convertData((Number) value); + quantityFeatures.put(module.getQuantityInner(keyValue.getKey()) + index - term.getKey(), feature); + } + } + module.associateInstance(qualityFeatures, quantityFeatures); + qualityFeatures.clear(); + quantityFeatures.clear(); + } + + @Override + public void startElement(String uri, String localName, String name, Attributes attributes) throws SAXException { + buffer.setLength(0); + } + + @Override + public void endElement(String uri, String localName, String name) throws SAXException { + try { + try (StringReader reader = new StringReader(buffer.toString()); BufferedReader stream = new BufferedReader(reader)) { + parseData(stream); + } + } catch (Exception exception) { + // TODO 处理日志. + throw new DataException(exception); + } + count++; + } + + @Override + public void characters(char characters[], int index, int length) throws SAXException { + buffer.append(characters, index, length); + } + + public int getCount() { + return count; + } + + } + + private QualityAttribute wordAttribute; + + public YongfengZhangDatasetConverter(QualityAttribute wordAttribute, Collection qualityAttributes, Collection quantityAttributes) { + super(qualityAttributes, quantityAttributes); + this.wordAttribute = wordAttribute; + } + + @Override + public int convert(DataModule module, InputStream iterator) { + try { + InputSource xmlSource = new InputSource(iterator); + SAXParserFactory saxFactory = SAXParserFactory.newInstance(); + SAXParser saxParser = saxFactory.newSAXParser(); + XMLReader sheetParser = saxParser.getXMLReader(); + YongfengZhangInstanceHandler handler = new YongfengZhangInstanceHandler(module); + sheetParser.setContentHandler(handler); + sheetParser.parse(xmlSource); + return handler.getCount(); + } catch (Exception exception) { + throw new RuntimeException(exception); + } + } + +} diff --git a/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangDatasetTestCase.java b/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangDatasetTestCase.java new file mode 100644 index 0000000..09c8735 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/data/converter/YongfengZhangDatasetTestCase.java @@ -0,0 +1,79 @@ +package com.jstarcraft.rns.data.converter; + +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; + +import javax.xml.parsers.SAXParser; +import javax.xml.parsers.SAXParserFactory; + +import org.junit.Assert; +import org.junit.Test; +import org.xml.sax.InputSource; +import org.xml.sax.XMLReader; + +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import 
com.jstarcraft.ai.data.attribute.QualityAttribute; +import com.jstarcraft.ai.data.converter.DataConverter; + +/** + * 演示如何自定义处理YongfengZhang数据集 + * + * @author Birdy + * + */ +public class YongfengZhangDatasetTestCase { + + @Test + public void testDataset() throws Exception { + File file = new File("data/labeled_DC/DC_feature_opinion/DC.txt"); + + // 定义数据空间 + Map> qualityDifinitions = new HashMap<>(); + qualityDifinitions.put("user", String.class); + qualityDifinitions.put("item", String.class); + qualityDifinitions.put("word", String.class); + Map> quantityDifinitions = new HashMap<>(); + quantityDifinitions.put("score", Float.class); + quantityDifinitions.put("sentiment", Float.class); + DataSpace dataSpace = new DataSpace(qualityDifinitions, quantityDifinitions); + + // 处理数据属性 + try (InputStream stream = new FileInputStream(file)) { + InputSource xmlSource = new InputSource(stream); + SAXParserFactory saxFactory = SAXParserFactory.newInstance(); + SAXParser saxParser = saxFactory.newSAXParser(); + XMLReader sheetParser = saxParser.getXMLReader(); + YongfengZhangAttributeHandler handler = new YongfengZhangAttributeHandler(dataSpace); + sheetParser.setContentHandler(handler); + sheetParser.parse(xmlSource); + } + QualityAttribute userAttribute = dataSpace.getQualityAttribute("user"); + QualityAttribute itemAttribute = dataSpace.getQualityAttribute("item"); + QualityAttribute wordAttribute = dataSpace.getQualityAttribute("word"); + Assert.assertEquals(89373, userAttribute.getSize()); + Assert.assertEquals(2397, itemAttribute.getSize()); + Assert.assertEquals(333, wordAttribute.getSize()); + + // 定义数据模块 + // 使用word属性大小作为sentiment特征维度 + TreeMap configuration = new TreeMap<>(); + configuration.put(1, "user"); + configuration.put(2, "item"); + configuration.put(3, "score"); + configuration.put(3 + wordAttribute.getSize(), "sentiment"); + DataModule dataModule = dataSpace.makeSparseModule("score", configuration, 1000000); + + // 处理数据实例 + DataConverter convertor = new YongfengZhangDatasetConverter(wordAttribute, dataSpace.getQualityAttributes(), dataSpace.getQuantityAttributes()); + try (InputStream stream = new FileInputStream(file)) { + convertor.convert(dataModule, stream); + } + Assert.assertEquals(123732, dataModule.getSize()); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/data/separator/DataSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/DataSeparator.java new file mode 100644 index 0000000..61ff916 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/data/separator/DataSeparator.java @@ -0,0 +1,40 @@ +package com.jstarcraft.rns.data.separator; + +import com.jstarcraft.ai.data.module.ReferenceModule; + +/** + * 数据分割器 + * + *
+ * Splits a data module into a training module and a testing module.
+ * </pre>
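+ *
+ * A typical driving loop (a sketch only; {@code separator} stands for any implementation):
+ *
+ * <pre>
+ * for (int index = 0; index < separator.getSize(); index++) {
+ *     ReferenceModule trainModule = separator.getTrainReference(index);
+ *     ReferenceModule testModule = separator.getTestReference(index);
+ *     // train with trainModule, evaluate with testModule
+ * }
+ * </pre>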
+ *
+ * @author Birdy
+ *
+ */
+public interface DataSeparator {
+
+    /**
+     * Gets the number of splits.
+     *
+     * @return
+     */
+    int getSize();
+
+    /**
+     * Gets the training reference.
+     *
+     * @param index
+     * @return
+     */
+    ReferenceModule getTrainReference(int index);
+
+    /**
+     * Gets the testing reference.
+     *
+     * @param index
+     * @return
+     */
+    ReferenceModule getTestReference(int index);
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/data/separator/GivenDataSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/GivenDataSeparator.java
new file mode 100644
index 0000000..f6c29f1
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/data/separator/GivenDataSeparator.java
@@ -0,0 +1,50 @@
+package com.jstarcraft.rns.data.separator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.IntegerArray;
+import com.jstarcraft.ai.data.module.ReferenceModule;
+
+/**
+ * Given-data separator (records before the threshold are training, the rest are testing)
+ *
+ * @author Birdy
+ *
+ */
+public class GivenDataSeparator implements DataSeparator {
+
+    private DataModule dataModule;
+
+    private IntegerArray trainReference;
+
+    private IntegerArray testReference;
+
+    public GivenDataSeparator(DataModule dataModule, int threshold) {
+        this.dataModule = dataModule;
+        this.trainReference = new IntegerArray();
+        this.testReference = new IntegerArray();
+        int size = dataModule.getSize();
+        for (int index = 0; index < size; index++) {
+            if (index < threshold) {
+                this.trainReference.associateData(index);
+            } else {
+                this.testReference.associateData(index);
+            }
+        }
+    }
+
+    @Override
+    public int getSize() {
+        return 1;
+    }
+
+    @Override
+    public ReferenceModule getTrainReference(int index) {
+        return new ReferenceModule(trainReference, dataModule);
+    }
+
+    @Override
+    public ReferenceModule getTestReference(int index) {
+        return new ReferenceModule(testReference, dataModule);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/data/separator/GivenInstanceSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/GivenInstanceSeparator.java
new file mode 100644
index 0000000..3d1d37b
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/data/separator/GivenInstanceSeparator.java
@@ -0,0 +1,53 @@
+package com.jstarcraft.rns.data.separator;
+
+import com.jstarcraft.ai.data.DataInstance;
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.IntegerArray;
+import com.jstarcraft.ai.data.module.ReferenceModule;
+import com.jstarcraft.ai.data.processor.DataSelector;
+
+/**
+ * Given-instance separator (instances matched by the selector are testing, the rest are training)
+ *
+ * @author Birdy
+ *
+ */
+public class GivenInstanceSeparator implements DataSeparator {
+
+    private DataModule dataModule;
+
+    private IntegerArray trainReference;
+
+    private IntegerArray testReference;
+
+    public GivenInstanceSeparator(DataModule dataModule, DataSelector selector) {
+        this.dataModule = dataModule;
+
+        this.trainReference = new IntegerArray();
+        this.testReference = new IntegerArray();
+        int position = 0;
+        for (DataInstance instance : dataModule) {
+            if (selector.select(instance)) {
+                this.testReference.associateData(position++);
+            } else {
+                this.trainReference.associateData(position++);
+            }
+        }
+    }
+
+    @Override
+    public int getSize() {
+        return 1;
+    }
+
+    @Override
+    public ReferenceModule getTrainReference(int index) {
+        return new ReferenceModule(trainReference, dataModule);
+    }
+
+    @Override
+    public ReferenceModule getTestReference(int index) {
+        return new ReferenceModule(testReference, dataModule);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/data/separator/GivenNumberSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/GivenNumberSeparator.java
new file mode 100644
index 0000000..6f84de5
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/data/separator/GivenNumberSeparator.java
@@ -0,0 +1,86 @@
+package com.jstarcraft.rns.data.separator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.IntegerArray;
+import com.jstarcraft.ai.data.module.ReferenceModule;
+import com.jstarcraft.ai.data.processor.DataSorter;
+import com.jstarcraft.ai.data.processor.DataSplitter;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSorter;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter;
+import com.jstarcraft.rns.data.processor.QuantityFeatureDataSorter;
+import com.jstarcraft.rns.data.processor.RandomDataSorter;
+
+/**
+ * Given-number separator (per group, the first number records are training, the rest are testing)
+ *
+ * @author Birdy
+ *
+ */
+public class GivenNumberSeparator implements DataSeparator {
+
+    private DataModule dataModule;
+
+    private IntegerArray trainReference;
+
+    private IntegerArray testReference;
+
+    public GivenNumberSeparator(DataSpace space, DataModule dataModule, String matchField, String sortField, int number) {
+        this.dataModule = dataModule;
+        ReferenceModule[] modules;
+        if (matchField == null) {
+            modules = new ReferenceModule[] { new ReferenceModule(dataModule) };
+        } else {
+            int matchDimension = dataModule.getQualityInner(matchField);
+            DataSplitter splitter = new QualityFeatureDataSplitter(matchDimension);
+            int size = space.getQualityAttribute(matchField).getSize();
+            modules = splitter.split(dataModule, size);
+        }
+        DataSorter sorter;
+        if (dataModule.getQualityInner(sortField) >= 0) {
+            int sortDimension = dataModule.getQualityInner(sortField);
+            sorter = new QualityFeatureDataSorter(sortDimension);
+        } else if (dataModule.getQuantityInner(sortField) >= 0) {
+            int sortDimension = dataModule.getQuantityInner(sortField);
+            sorter = new QuantityFeatureDataSorter(sortDimension);
+        } else {
+            sorter = new RandomDataSorter();
+        }
+        for (int index = 0, size = modules.length; index < size; index++) {
+            IntegerArray oldReference = modules[index].getReference();
+            IntegerArray newReference = sorter.sort(modules[index]).getReference();
+            // Map the sorted positions back to positions in the original module.
+            for (int cursor = 0, length = newReference.getSize(); cursor < length; cursor++) {
+                newReference.setData(cursor, oldReference.getData(newReference.getData(cursor)));
+            }
+            modules[index] = new ReferenceModule(newReference, dataModule);
+        }
+        this.trainReference = new IntegerArray();
+        this.testReference = new IntegerArray();
+        for (ReferenceModule module : modules) {
+            IntegerArray reference = module.getReference();
+            for (int cursor = 0, length = reference.getSize(); cursor < length; cursor++) {
+                if (cursor < number) {
+                    this.trainReference.associateData(reference.getData(cursor));
+                } else {
+                    this.testReference.associateData(reference.getData(cursor));
+                }
+            }
+        }
+    }
+
+    @Override
+    public int getSize() {
+        return 1;
+    }
+
+    @Override
+    public ReferenceModule getTrainReference(int index) {
+        return new ReferenceModule(trainReference, dataModule);
+    }
+
+    @Override
+    public ReferenceModule getTestReference(int index) {
+        return new ReferenceModule(testReference, dataModule);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/data/separator/KFoldCrossValidationSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/KFoldCrossValidationSeparator.java
new file mode 100644
index 0000000..10ccd8c
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/data/separator/KFoldCrossValidationSeparator.java
@@ -0,0 +1,59 @@
+package com.jstarcraft.rns.data.separator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.IntegerArray;
+import com.jstarcraft.ai.data.module.ReferenceModule;
+import com.jstarcraft.core.utility.RandomUtility;
+
+/**
+ * K-fold cross-validation separator
+ *
+ * @author Birdy
+ */
+public class KFoldCrossValidationSeparator implements DataSeparator {
+
+    private DataModule dataModule;
+
+    private Integer[] folds;
+
+    private int number;
+
+    public KFoldCrossValidationSeparator(DataModule dataModule, int number) {
+        this.dataModule = dataModule;
+        this.number = number;
+        this.folds = new Integer[this.dataModule.getSize()];
+        for (int index = 0, size = this.folds.length; index < size; index++) {
+            this.folds[index] = index % number;
+        }
+        // Shuffle the fold assignments by random swaps.
+        RandomUtility.shuffle(this.folds);
+    }
+
+    @Override
+    public int getSize() {
+        return number;
+    }
+
+    @Override
+    public ReferenceModule getTrainReference(int index) {
+        IntegerArray reference = new IntegerArray();
+        for (int position = 0, size = dataModule.getSize(); position < size; position++) {
+            if (folds[position] != index) {
+                reference.associateData(position);
+            }
+        }
+        return new ReferenceModule(reference, dataModule);
+    }
+
+    @Override
+    public ReferenceModule getTestReference(int index) {
+        IntegerArray reference = new IntegerArray();
+        for (int position = 0, size = dataModule.getSize(); position < size; position++) {
+            if (folds[position] == index) {
+                reference.associateData(position);
+            }
+        }
+        return new ReferenceModule(reference, dataModule);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/data/separator/LeaveOneCrossValidationSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/LeaveOneCrossValidationSeparator.java
new file mode 100644
index 0000000..f90a517
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/data/separator/LeaveOneCrossValidationSeparator.java
@@ -0,0 +1,86 @@
+package com.jstarcraft.rns.data.separator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.IntegerArray;
+import com.jstarcraft.ai.data.module.ReferenceModule;
+import com.jstarcraft.ai.data.processor.DataSorter;
+import com.jstarcraft.ai.data.processor.DataSplitter;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSorter;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter;
+import com.jstarcraft.rns.data.processor.QuantityFeatureDataSorter;
+import com.jstarcraft.rns.data.processor.RandomDataSorter;
+
+/**
+ * Leave-one-out cross-validation separator
+ *
+ * @author Birdy
+ *
+ */
+public class LeaveOneCrossValidationSeparator implements DataSeparator {
+
+    private DataModule dataModule;
+
+    private IntegerArray trainReference;
+
+    private IntegerArray testReference;
+
+    public LeaveOneCrossValidationSeparator(DataSpace space, DataModule dataModule, String matchField, String sortField) {
+        this.dataModule = dataModule;
+        ReferenceModule[] modules;
+        if (matchField == null) {
+            modules = new ReferenceModule[] { new ReferenceModule(dataModule) };
+        } else {
+            int matchDimension = dataModule.getQualityInner(matchField);
+            DataSplitter splitter = new QualityFeatureDataSplitter(matchDimension);
+            int size = space.getQualityAttribute(matchField).getSize();
+            modules = splitter.split(dataModule, size);
+        }
+        DataSorter sorter;
+        if (dataModule.getQualityInner(sortField) >= 0) {
+            int sortDimension = dataModule.getQualityInner(sortField);
+            sorter = new QualityFeatureDataSorter(sortDimension);
+        } else if (dataModule.getQuantityInner(sortField) >= 0) {
+            int sortDimension = dataModule.getQuantityInner(sortField);
+            sorter = new QuantityFeatureDataSorter(sortDimension);
+        } else {
+            sorter = new RandomDataSorter();
+        }
+        for (int index = 0, size = modules.length; index < size; index++) {
+            IntegerArray oldReference = modules[index].getReference();
+            IntegerArray newReference = sorter.sort(modules[index]).getReference();
+            for (int cursor = 0, length = newReference.getSize(); cursor < length; cursor++) {
+                newReference.setData(cursor, oldReference.getData(newReference.getData(cursor)));
+            }
+            modules[index] = new ReferenceModule(newReference, dataModule);
+        }
+        this.trainReference = new IntegerArray();
+        this.testReference = new IntegerArray();
+        for (ReferenceModule module : modules) {
+            IntegerArray reference = module.getReference();
+            for (int cursor = 0, length = reference.getSize(); cursor < length; cursor++) {
+                // The last record of every group is left out for testing.
+                if (length - cursor == 1) {
+                    this.testReference.associateData(reference.getData(cursor));
+                } else {
+                    this.trainReference.associateData(reference.getData(cursor));
+                }
+            }
+        }
+    }
+
+    @Override
+    public int getSize() {
+        return 1;
+    }
+
+    @Override
+    public ReferenceModule getTrainReference(int index) {
+        return new ReferenceModule(trainReference, dataModule);
+    }
+
+    @Override
+    public ReferenceModule getTestReference(int index) {
+        return new ReferenceModule(testReference, dataModule);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/data/separator/RandomSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/RandomSeparator.java
new file mode 100644
index 0000000..2a8e3f0
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/data/separator/RandomSeparator.java
@@ -0,0 +1,65 @@
+package com.jstarcraft.rns.data.separator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.IntegerArray;
+import com.jstarcraft.ai.data.module.ReferenceModule;
+import com.jstarcraft.ai.data.processor.DataSplitter;
+import com.jstarcraft.core.utility.RandomUtility;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter;
+
+/**
+ * Random separator
+ *
+ * @author Birdy
+ *
+ */
+public class RandomSeparator implements DataSeparator {
+
+    private DataModule dataModule;
+
+    private IntegerArray trainReference;
+
+    private IntegerArray testReference;
+
+    public RandomSeparator(DataSpace space, DataModule dataModule, String matchField, float random) {
+        this.dataModule = dataModule;
+        ReferenceModule[] modules;
+        if (matchField == null) {
+            modules = new ReferenceModule[] { new ReferenceModule(dataModule) };
+        } else {
+            int matchDimension = dataModule.getQualityInner(matchField);
+            DataSplitter splitter = new QualityFeatureDataSplitter(matchDimension);
+            int size = space.getQualityAttribute(matchField).getSize();
+            modules = splitter.split(dataModule, size);
+        }
+        this.trainReference = new IntegerArray();
+        this.testReference = new IntegerArray();
+        for (ReferenceModule module : modules) {
+            IntegerArray reference = module.getReference();
+            for (int cursor = 0, length = reference.getSize(); cursor < length; cursor++) {
+                // Each record goes to training with probability random, otherwise to testing.
+                if (RandomUtility.randomFloat(1F) < random) {
+                    this.trainReference.associateData(reference.getData(cursor));
+                } else {
+                    this.testReference.associateData(reference.getData(cursor));
+                }
+            }
+        }
+    }
+
+    @Override
+    public int getSize() {
+        return 1;
+    }
+
+    @Override
+    public ReferenceModule getTrainReference(int index) {
+        return new ReferenceModule(trainReference, dataModule);
+    }
+
+    @Override
+    public ReferenceModule getTestReference(int index) {
+        return new ReferenceModule(testReference, dataModule);
+    }
+
+}
\ No newline at end of file
diff --git a/src/test/java/com/jstarcraft/rns/data/separator/RatioSeparator.java b/src/test/java/com/jstarcraft/rns/data/separator/RatioSeparator.java
new file mode 100644
index 0000000..86e2590
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/data/separator/RatioSeparator.java
@@ -0,0 +1,88 @@
+package com.jstarcraft.rns.data.separator;
+
+import com.jstarcraft.ai.data.DataModule;
+import com.jstarcraft.ai.data.DataSpace;
+import com.jstarcraft.ai.data.IntegerArray;
+import com.jstarcraft.ai.data.module.ReferenceModule;
+import com.jstarcraft.ai.data.processor.DataSorter;
+import com.jstarcraft.ai.data.processor.DataSplitter;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSorter;
+import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter;
+import com.jstarcraft.rns.data.processor.QuantityFeatureDataSorter;
+import com.jstarcraft.rns.data.processor.RandomDataSorter;
+
+/**
+ * Ratio separator
+ *
+ * @author Birdy
+ *
+ */
+public class RatioSeparator implements DataSeparator {
+
+    private DataModule dataModule;
+
+    private IntegerArray trainReference;
+
+    private IntegerArray testReference;
+
+    public RatioSeparator(DataSpace space, DataModule dataModule, String matchField, String sortField, float ratio) {
+        this.dataModule = dataModule;
+        ReferenceModule[] modules;
+        if (matchField == null) {
+            modules = new ReferenceModule[] { new ReferenceModule(dataModule) };
+        } else {
+            int matchDimension = dataModule.getQualityInner(matchField);
+            DataSplitter splitter = new QualityFeatureDataSplitter(matchDimension);
+            int size = space.getQualityAttribute(matchField).getSize();
+            modules = splitter.split(dataModule, size);
+        }
+        DataSorter sorter;
+        if (dataModule.getQualityInner(sortField) >= 0) {
+            int sortDimension = dataModule.getQualityInner(sortField);
+            sorter = new QualityFeatureDataSorter(sortDimension);
+        } else if (dataModule.getQuantityInner(sortField) >= 0) {
+            int sortDimension = dataModule.getQuantityInner(sortField);
+            sorter = new QuantityFeatureDataSorter(sortDimension);
+        } else {
+            sorter = new RandomDataSorter();
+        }
+        for (int index = 0, size = modules.length; index < size; index++) {
+            IntegerArray oldReference = modules[index].getReference();
+            IntegerArray newReference = sorter.sort(modules[index]).getReference();
+            for (int cursor = 0, length = newReference.getSize(); cursor < length; cursor++) {
+                newReference.setData(cursor, oldReference.getData(newReference.getData(cursor)));
+            }
+            modules[index] = new ReferenceModule(newReference, dataModule);
+        }
+        this.trainReference = new IntegerArray();
+        this.testReference = new IntegerArray();
+        for (ReferenceModule module : modules) {
+            int count = 0;
+            // The first ratio share of every group is training, the rest is testing.
+            int number = (int) (module.getSize() * ratio);
+            IntegerArray reference = module.getReference();
+            for (int cursor = 0, length = reference.getSize(); cursor < length; cursor++) {
+                if (count++ < number) {
+                    this.trainReference.associateData(reference.getData(cursor));
+                } else {
+                    this.testReference.associateData(reference.getData(cursor));
+                }
+            }
+        }
+    }
+
+    @Override
+    public int getSize() {
+        return 1;
+    }
+
+    @Override
+    public ReferenceModule getTrainReference(int index) {
+        return new ReferenceModule(trainReference, dataModule);
+    }
+
+    @Override
+    public ReferenceModule getTestReference(int index) {
+        return new ReferenceModule(testReference, dataModule);
+    }
+
+}
\ No newline at end of file
diff --git a/src/test/java/com/jstarcraft/rns/model/AutoRecLossFunctionTestCase.java b/src/test/java/com/jstarcraft/rns/model/AutoRecLossFunctionTestCase.java
new file mode 100644
index 0000000..416345c
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/AutoRecLossFunctionTestCase.java
@@ -0,0 +1,24 @@
+package com.jstarcraft.rns.model;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+
+import com.jstarcraft.ai.math.structure.matrix.Nd4jMatrix;
+import com.jstarcraft.ai.model.neuralnetwork.activation.ActivationFunction;
+import com.jstarcraft.ai.model.neuralnetwork.loss.LossFunction;
+import com.jstarcraft.rns.model.collaborative.rating.AutoRecLearner;
+import com.jstarcraft.rns.model.neuralnetwork.AutoRecLossFunction;
+
+public class AutoRecLossFunctionTestCase extends LossFunctionTestCase {
+
+    @Override
+    protected ILossFunction getOldFunction(INDArray masks) {
+        return new AutoRecLearner(masks);
+    }
+
+    @Override
+    protected LossFunction getNewFunction(INDArray masks, ActivationFunction function) {
+        return new AutoRecLossFunction(new Nd4jMatrix(masks));
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/LossFunctionTestCase.java b/src/test/java/com/jstarcraft/rns/model/LossFunctionTestCase.java
new file mode 100644
index 0000000..dda8dc1
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/LossFunctionTestCase.java
@@ -0,0 +1,107 @@
+package com.jstarcraft.rns.model;
+
+import java.util.LinkedList;
+import java.util.concurrent.Future;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationSigmoid;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+
+import com.jstarcraft.ai.environment.EnvironmentContext;
+import com.jstarcraft.ai.environment.EnvironmentFactory;
+import com.jstarcraft.ai.math.MathUtility;
+import com.jstarcraft.ai.math.structure.matrix.MathMatrix;
+import com.jstarcraft.ai.math.structure.matrix.Nd4jMatrix;
+import com.jstarcraft.ai.model.neuralnetwork.activation.ActivationFunction;
+import com.jstarcraft.ai.model.neuralnetwork.activation.SigmoidActivationFunction;
+import com.jstarcraft.ai.model.neuralnetwork.loss.LossFunction;
+import com.jstarcraft.core.utility.KeyValue;
+
+public abstract class LossFunctionTestCase {
+
+    protected static Nd4jMatrix getMatrix(INDArray array) {
+        return new Nd4jMatrix(array);
+    }
+
+    protected static boolean equalMatrix(MathMatrix matrix, INDArray array) {
+        for (int row = 0; row < matrix.getRowSize(); row++) {
+            for (int column = 0; column < matrix.getColumnSize(); column++) {
+                if (Math.abs(matrix.getValue(row, column) - array.getFloat(row, column)) > MathUtility.EPSILON) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    protected abstract ILossFunction getOldFunction(INDArray masks);
+
+    protected abstract LossFunction getNewFunction(INDArray masks, ActivationFunction function);
+
+    @Test
+    public void testScore() throws Exception {
+        EnvironmentContext context = EnvironmentFactory.getContext();
+        Future<?> task = context.doTask(() -> {
+            LinkedList<KeyValue<IActivation, ActivationFunction>> activationList = new LinkedList<>();
+            activationList.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
+//            activationList.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
+            for (KeyValue<IActivation, ActivationFunction> keyValue : activationList) {
+                INDArray array = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
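+                // Score with the legacy ND4J ILossFunction and with the new
+                // jstarcraft LossFunction, then require both scores to agree
+                // within EPSILON.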
+                INDArray marks = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
+                ILossFunction oldFunction = getOldFunction(marks);
+                double value = oldFunction.computeScore(marks, array.dup(), keyValue.getKey(), null, false);
+
+                Nd4jMatrix input = getMatrix(array.dup());
+                Nd4jMatrix output = new Nd4jMatrix(Nd4j.zeros(input.getRowSize(), input.getColumnSize()));
+                ActivationFunction function = keyValue.getValue();
+                function.forward(input, output);
+                LossFunction newFunction = getNewFunction(marks, function);
+                newFunction.doCache(getMatrix(marks), output);
+                double score = newFunction.computeScore(getMatrix(marks), output, null);
+
+                System.out.println(value);
+                System.out.println(score);
+
+                if (Math.abs(value - score) > MathUtility.EPSILON) {
+                    Assert.fail();
+                }
+            }
+        });
+        task.get();
+    }
+
+    @Test
+    public void testGradient() throws Exception {
+        EnvironmentContext context = EnvironmentFactory.getContext();
+        Future<?> task = context.doTask(() -> {
+            LinkedList<KeyValue<IActivation, ActivationFunction>> activationList = new LinkedList<>();
+            activationList.add(new KeyValue<>(new ActivationSigmoid(), new SigmoidActivationFunction()));
+//            activationList.add(new KeyValue<>(new ActivationSoftmax(), new SoftMaxActivationFunction()));
+            for (KeyValue<IActivation, ActivationFunction> keyValue : activationList) {
+                INDArray array = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(5, 2);
+                INDArray marks = Nd4j.create(new double[] { 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D, 0D, 1D }).reshape(5, 2);
+                ILossFunction oldFunction = getOldFunction(marks);
+                INDArray value = oldFunction.computeGradient(marks, array.dup(), keyValue.getKey(), null);
+
+                Nd4jMatrix input = getMatrix(array.dup());
+                Nd4jMatrix output = new Nd4jMatrix(Nd4j.zeros(input.getRowSize(), input.getColumnSize()));
+                ActivationFunction function = keyValue.getValue();
+                function.forward(input, output);
+                Nd4jMatrix gradient = new Nd4jMatrix(Nd4j.zeros(input.getRowSize(), input.getColumnSize()));
+                LossFunction newFunction = getNewFunction(marks, function);
+                newFunction.doCache(getMatrix(marks), output);
+                newFunction.computeGradient(getMatrix(marks), output, null, gradient);
+                // Back-propagate the gradient through the activation for comparison.
+                function.backward(input, gradient, output);
+                System.out.println(value);
+                System.out.println(output);
+                Assert.assertTrue(equalMatrix(output, value));
+            }
+        });
+        task.get();
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/ModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/ModelTestCase.java
new file mode 100644
index 0000000..0807d13
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/ModelTestCase.java
@@ -0,0 +1,17 @@
+package com.jstarcraft.rns.model;
+
+import com.jstarcraft.ai.modem.ModemCodec;
+import com.jstarcraft.rns.task.AbstractTask;
+
+public abstract class ModelTestCase {
+
+    public void testModem(AbstractTask job) {
+        // Every codec should round-trip the model through encode and decode.
+        for (ModemCodec codec : ModemCodec.values()) {
+            Model oldModel = job.getModel();
+            byte[] data = codec.encodeModel(oldModel);
+            Model newModel = (Model) codec.decodeModel(data);
+//            Assert.assertThat(newModel.predict(new int[] { 0, 1 }, new float[] {}), CoreMatchers.equalTo(oldModel.predict(new int[] { 0, 1 }, new float[] {})));
+        }
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/ModelTestSuite.java b/src/test/java/com/jstarcraft/rns/model/ModelTestSuite.java
new file mode 100644
index 0000000..0cc8b69
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/ModelTestSuite.java
@@ -0,0 +1,27 @@
+package com.jstarcraft.rns.model;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+import com.jstarcraft.rns.model.benchmark.BenchmarkTestSuite;
+import com.jstarcraft.rns.model.collaborative.CollaborativeTestSuite;
+import com.jstarcraft.rns.model.content.ContentTestSuite;
+import com.jstarcraft.rns.model.context.ContextTestSuite;
+import com.jstarcraft.rns.model.extend.ExtendTestSuite;
+
+@RunWith(Suite.class)
+@SuiteClasses({
+        // Recommender test suites (JVM arguments: -Xms1024M -Xmx8192M -ea)
+        BenchmarkTestSuite.class,
+
+        CollaborativeTestSuite.class,
+
+        ContentTestSuite.class,
+
+        ContextTestSuite.class,
+
+        ExtendTestSuite.class, })
+public class ModelTestSuite {
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/BenchmarkTestSuite.java b/src/test/java/com/jstarcraft/rns/model/benchmark/BenchmarkTestSuite.java
new file mode 100644
index 0000000..1c42976
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/benchmark/BenchmarkTestSuite.java
@@ -0,0 +1,35 @@
+package com.jstarcraft.rns.model.benchmark;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+import com.jstarcraft.rns.model.benchmark.ranking.MostPopularModelTestCase;
+import com.jstarcraft.rns.model.benchmark.rating.ConstantGuessModelTestCase;
+import com.jstarcraft.rns.model.benchmark.rating.GlobalAverageModelTestCase;
+import com.jstarcraft.rns.model.benchmark.rating.ItemAverageModelTestCase;
+import com.jstarcraft.rns.model.benchmark.rating.ItemClusterModelTestCase;
+import com.jstarcraft.rns.model.benchmark.rating.UserAverageModelTestCase;
+import com.jstarcraft.rns.model.benchmark.rating.UserClusterModelTestCase;
+
+@RunWith(Suite.class)
+@SuiteClasses({
+        // Recommender test cases
+        ConstantGuessModelTestCase.class,
+
+        GlobalAverageModelTestCase.class,
+
+        ItemAverageModelTestCase.class,
+
+        ItemClusterModelTestCase.class,
+
+        MostPopularModelTestCase.class,
+
+        RandomGuessModelTestCase.class,
+
+        UserAverageModelTestCase.class,
+
+        UserClusterModelTestCase.class })
+public class BenchmarkTestSuite {
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/RandomGuessModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/RandomGuessModelTestCase.java
new file mode 100644
index 0000000..4090faf
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/benchmark/RandomGuessModelTestCase.java
@@ -0,0 +1,58 @@
+package com.jstarcraft.rns.model.benchmark;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MAEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MPEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MSEEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+import com.jstarcraft.rns.task.RatingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class RandomGuessModelTestCase {
+
+    @Test
+    public void testRecommenderByRanking() throws Exception {
+        Properties keyValues = new Properties();
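+        // Layer the options: dataset options first, then the model options;
+        // when a key repeats, the later file wins.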
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/randomguess-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(RandomGuessModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.5192176F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.006268634F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.021699615F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.01120969F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(91.949F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.0055039763F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.012620986F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+    @Test
+    public void testRecommenderByRating() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/randomguess-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RatingTask job = new RatingTask(RandomGuessModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(1.2862209F, measures.getFloat(MAEEvaluator.class), 0F);
+        Assert.assertEquals(0.9959667F, measures.getFloat(MPEEvaluator.class), 0F);
+        Assert.assertEquals(2.479267F, measures.getFloat(MSEEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/ranking/MostPopularModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/ranking/MostPopularModelTestCase.java
new file mode 100644
index 0000000..86e64a6
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/benchmark/ranking/MostPopularModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.benchmark.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class MostPopularModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/mostpopular-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(MostPopularModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.9207961F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.4124602F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.571964F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5158319F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(11.792954F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.33229586F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.62384576F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ConstantGuessModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ConstantGuessModelTestCase.java
new file mode 100644
index 0000000..31fb1d3
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ConstantGuessModelTestCase.java
@@ -0,0 +1,34 @@
+package com.jstarcraft.rns.model.benchmark.rating;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.rating.MAEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MPEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MSEEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RatingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class ConstantGuessModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/constantguess-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RatingTask job = new RatingTask(ConstantGuessModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(1.0560759F, measures.getFloat(MAEEvaluator.class), 0F);
+        Assert.assertEquals(1.0F, measures.getFloat(MPEEvaluator.class), 0F);
+        Assert.assertEquals(1.4230907F, measures.getFloat(MSEEvaluator.class), 0F);
+
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/rating/GlobalAverageModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/GlobalAverageModelTestCase.java
new file mode 100644
index 0000000..0a2d976
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/GlobalAverageModelTestCase.java
@@ -0,0 +1,33 @@
+package com.jstarcraft.rns.model.benchmark.rating;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.rating.MAEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MPEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MSEEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RatingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class GlobalAverageModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/globalaverage-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RatingTask job = new RatingTask(GlobalAverageModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.71976596F, measures.getFloat(MAEEvaluator.class), 0F);
+        Assert.assertEquals(0.77907884F, measures.getFloat(MPEEvaluator.class), 0F);
+        Assert.assertEquals(0.85198635F, measures.getFloat(MSEEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ItemAverageModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ItemAverageModelTestCase.java
new file mode 100644
index 0000000..9282b6f
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ItemAverageModelTestCase.java
@@ -0,0 +1,33 @@
+package com.jstarcraft.rns.model.benchmark.rating;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.rating.MAEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MPEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MSEEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RatingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class ItemAverageModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/itemaverage-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RatingTask job = new RatingTask(ItemAverageModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.7296801F, measures.getFloat(MAEEvaluator.class), 0F);
+        Assert.assertEquals(0.97241735F, measures.getFloat(MPEEvaluator.class), 0F);
+        Assert.assertEquals(0.86413175F, measures.getFloat(MSEEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ItemClusterModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ItemClusterModelTestCase.java
new file mode 100644
index 0000000..7c1cddb
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/ItemClusterModelTestCase.java
@@ -0,0 +1,33 @@
+package com.jstarcraft.rns.model.benchmark.rating;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.rating.MAEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MPEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MSEEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RatingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class ItemClusterModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/itemcluster-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RatingTask job = new RatingTask(ItemClusterModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.7197603F, measures.getFloat(MAEEvaluator.class), 0F);
+        Assert.assertEquals(0.77907884F, measures.getFloat(MPEEvaluator.class), 0F);
+        Assert.assertEquals(0.8519833F, measures.getFloat(MSEEvaluator.class), 0F);
+    }
+
+}
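All of the benchmark test cases above follow the same pattern: merge the dataset options with the model options, run a task, and check the evaluator measures. A minimal standalone sketch of that pattern (the model class and properties paths here are placeholders, not part of this change):

    Properties keyValues = new Properties();
    keyValues.load(SomeModelTestCase.class.getResourceAsStream("/data/filmtrust.properties"));
    keyValues.load(SomeModelTestCase.class.getResourceAsStream("/model/benchmark/somemodel-test.properties"));
    Option configuration = new MapOption(keyValues);
    RatingTask job = new RatingTask(SomeModel.class, configuration);
    Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
    float mae = measures.getFloat(MAEEvaluator.class);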
diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/rating/UserAverageModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/UserAverageModelTestCase.java new file mode 100644 index 0000000..621a8a7 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/UserAverageModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class UserAverageModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/useraverage-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(UserAverageModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.6461777F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.97241735F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.70172423F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/benchmark/rating/UserClusterModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/UserClusterModelTestCase.java new file mode 100644 index 0000000..64c24d6 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/benchmark/rating/UserClusterModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.benchmark.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class UserClusterModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/benchmark/usercluster-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(UserClusterModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.7197661F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.77907884F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.851986F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/CollaborativeTestSuite.java b/src/test/java/com/jstarcraft/rns/model/collaborative/CollaborativeTestSuite.java new file mode 100644 index 
0000000..a1e2bf2 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/CollaborativeTestSuite.java @@ -0,0 +1,20 @@ +package com.jstarcraft.rns.model.collaborative; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +import com.jstarcraft.rns.model.collaborative.ranking.CollaborativeRankingTestSuite; +import com.jstarcraft.rns.model.collaborative.rating.CollaborativeRatingTestSuite; + +@RunWith(Suite.class) +@SuiteClasses({ + // 推荐器测试集 + // recommender.cf.ranking + CollaborativeRankingTestSuite.class, + + // recommender.cf.rating + CollaborativeRatingTestSuite.class, }) +public class CollaborativeTestSuite { + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/AoBPRModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/AoBPRModelTestCase.java new file mode 100644 index 0000000..0c08430 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/AoBPRModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class AoBPRModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/aobpr-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(AoBPRModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.8932367F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.38967326F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.53990144F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.48337516F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(21.130045F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.3229456F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.5686421F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/AspectModelRankingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/AspectModelRankingModelTestCase.java new file mode 100644 index 0000000..7f3e66d --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/AspectModelRankingModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import 
com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class AspectModelRankingModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/aspectmodelranking-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(AspectModelRankingModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.8513018F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.15497543F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.42479676F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.2601204F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(37.362732F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.13302413F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.31291583F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BHFreeRankingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BHFreeRankingModelTestCase.java new file mode 100644 index 0000000..7187404 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BHFreeRankingModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class BHFreeRankingModelTestCase { + + @Test + public void testRecommenderRanking() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/bhfreeranking-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(BHFreeRankingModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.9208008F, measures.getFloat(AUCEvaluator.class), 0F); + 
Assert.assertEquals(0.413161F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.5723087F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.51661724F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(11.79567F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.33276004F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.62499523F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BPRModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BPRModelTestCase.java new file mode 100644 index 0000000..8323191 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BPRModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class BPRModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/bpr-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(BPRModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.89390284F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.39886427F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.54790056F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.49179932F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(21.467379F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.32268023F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.5762345F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BUCMRankingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BUCMRankingModelTestCase.java new file mode 100644 index 0000000..1b0f861 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/BUCMRankingModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.collaborative.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import 
com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class BUCMRankingModelTestCase {
+
+    @Test
+    public void testRecommenderRanking() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/bucmranking-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(BUCMRankingModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.90781504F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.39793882F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5577617F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.49650663F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(13.080729F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.32407284F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.5914064F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CDAEModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CDAEModelTestCase.java
new file mode 100644
index 0000000..e5ec84f
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CDAEModelTestCase.java
@@ -0,0 +1,42 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.neuralnetwork.CDAEModel;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class CDAEModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/cdae-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(CDAEModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.9188042F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.40759084F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5685547F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5108937F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(11.824657F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.3305053F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.61967427F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CLiMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CLiMFModelTestCase.java
new file mode 100644
index 0000000..45d76a4
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CLiMFModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class CLiMFModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/climf-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(CLiMFModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8829291F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.3739544F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5240671F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.4657167F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(19.389643F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.32049185F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.5460514F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CollaborativeRankingTestSuite.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CollaborativeRankingTestSuite.java
new file mode 100644
index 0000000..eb0e523
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/CollaborativeRankingTestSuite.java
@@ -0,0 +1,70 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+@RunWith(Suite.class)
+@SuiteClasses({
+        // Recommender test suite
+        // recommender.cf.ranking
+        AoBPRModelTestCase.class,
+
+        AspectModelRankingModelTestCase.class,
+
+        BHFreeRankingModelTestCase.class,
+
+        BPRModelTestCase.class,
+
+        BUCMRankingModelTestCase.class,
+
+        CDAEModelTestCase.class,
+
+        CLiMFModelTestCase.class,
+
+// DeepCrossModelTestCase.class,
+
+        DeepFMModelTestCase.class,
+
+        EALSModelTestCase.class,
+
+        FISMAUCModelTestCase.class,
+
+        FISMRMSEModelTestCase.class,
+
+        GBPRModelTestCase.class,
+
+        HMMModelTestCase.class,
+
+        ItemBigramModelTestCase.class,
+
+        ItemKNNRankingModelTestCase.class,
+
+        LDAModelTestCase.class,
+
+        LambdaFMModelTestCase.class,
+
+        ListwiseMFModelTestCase.class,
+
+        PLSAModelTestCase.class,
+
+        RankALSModelTestCase.class,
+
+        RankCDModelTestCase.class,
+
+        RankSGDModelTestCase.class,
+
+        RankVFCDModelTestCase.class,
+
+        SLIMModelTestCase.class,
+
+        UserKNNRankingModelTestCase.class,
+
+        VBPRModelTestCase.class,
+
+        WBPRModelTestCase.class,
+
+        WRMFModelTestCase.class, })
+public class CollaborativeRankingTestSuite {
+
+}
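Every test case in this change set follows the same five-step shape: load a dataset properties file, overlay a model properties file, wrap both in a MapOption, run the task for the model class, and assert the returned measures. A minimal sketch of that shared shape follows; the helper class, its name, and the Model bound on the class parameter are illustrative assumptions, not part of this diff (the actual tests inline the logic):

    // Hypothetical helper sketching the pattern repeated by the test cases in this
    // change set; assumes com.jstarcraft.rns.model.Model is the model supertype
    // accepted by the RankingTask constructor used above.
    package com.jstarcraft.rns.model.collaborative.ranking;

    import java.util.Properties;

    import com.jstarcraft.ai.evaluate.Evaluator;
    import com.jstarcraft.core.common.option.MapOption;
    import com.jstarcraft.core.common.option.Option;
    import com.jstarcraft.rns.model.Model;
    import com.jstarcraft.rns.task.RankingTask;

    import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;

    public class RankingTestSupport {

        // dataResource and modelResource are classpath locations, for example
        // "/data/filmtrust.properties" and "/model/collaborative/ranking/bpr-test.properties".
        public static Object2FloatSortedMap<Class<? extends Evaluator>> evaluate(Class<? extends Model> clazz, String dataResource, String modelResource) throws Exception {
            // Merge the dataset configuration with the model-specific configuration.
            Properties keyValues = new Properties();
            keyValues.load(RankingTestSupport.class.getResourceAsStream(dataResource));
            keyValues.load(RankingTestSupport.class.getResourceAsStream(modelResource));
            Option configuration = new MapOption(keyValues);
            // Train and evaluate the model, returning the metric-to-score map.
            RankingTask job = new RankingTask(clazz, configuration);
            return job.execute();
        }

    }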
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/DeepCrossModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/DeepCrossModelTestCase.java
new file mode 100644
index 0000000..209b6e3
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/DeepCrossModelTestCase.java
@@ -0,0 +1,42 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.neuralnetwork.DeepCrossModel;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class DeepCrossModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/deepcross-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(DeepCrossModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.91646796F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.39583597F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5631354F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.50070804F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(12.073422F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.32440427F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.60906875F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/DeepFMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/DeepFMModelTestCase.java
new file mode 100644
index 0000000..93b84aa
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/DeepFMModelTestCase.java
@@ -0,0 +1,42 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.model.neuralnetwork.DeepFMModel;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class DeepFMModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/deepfm-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(DeepFMModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.916794F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.4057996F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5699482F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.509845F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(11.902421F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.3271896F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.6142564F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/EALSModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/EALSModelTestCase.java
new file mode 100644
index 0000000..15db459
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/EALSModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class EALSModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/eals-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(EALSModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8613155F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.3126306F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.4568016F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.3947456F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(20.089636F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.27380753F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.46271476F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/FISMAUCModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/FISMAUCModelTestCase.java
new file mode 100644
index 0000000..e744616
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/FISMAUCModelTestCase.java
@@ -0,0 +1,42 @@
+
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class FISMAUCModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/fismauc-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(FISMAUCModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.91215646F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.40031773F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5572973F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5011414F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(12.074685F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.3284496F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.6029353F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/FISMRMSEModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/FISMRMSEModelTestCase.java
new file mode 100644
index 0000000..64ae829
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/FISMRMSEModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class FISMRMSEModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/fismrmse-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(FISMRMSEModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.91482186F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.40795466F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5646953F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.509198F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(11.912338F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.33043906F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.6110687F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/GBPRModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/GBPRModelTestCase.java
new file mode 100644
index 0000000..7b96df5
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/GBPRModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class GBPRModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/gbpr-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(GBPRModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.9211258F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.4100329F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.57143813F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5146437F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(11.876095F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.33090335F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.6251202F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/HMMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/HMMModelTestCase.java
new file mode 100644
index 0000000..a48cbba
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/HMMModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class HMMModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/game.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/hmm-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(HMMModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8055914F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.18155931F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.37515923F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.25802785F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(16.010412F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.14572115F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.2281047F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ItemBigramModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ItemBigramModelTestCase.java
new file mode 100644
index 0000000..7099160
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ItemBigramModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class ItemBigramModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/itembigram-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(ItemBigramModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8880696F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.33519757F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.46869788F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.42853972F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(17.111723F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.29191107F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.5330829F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ItemKNNRankingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ItemKNNRankingModelTestCase.java
new file mode 100644
index 0000000..5fcba9d
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ItemKNNRankingModelTestCase.java
@@ -0,0 +1,42 @@
+
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class ItemKNNRankingModelTestCase {
+
+    @Test
+    public void testRecommenderRanking() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/itemknnranking-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(ItemKNNRankingModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.87437975F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.3337493F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.4695067F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.4176727F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(20.234493F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.28581026F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.49248183F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/LDAModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/LDAModelTestCase.java
new file mode 100644
index 0000000..a375345
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/LDAModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class LDAModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/lda-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(LDAModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.919801F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.41758165F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.58130056F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5200285F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(12.313484F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.33335692F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.62273633F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMModelTestCase.java
new file mode 100644
index 0000000..ce8d0bb
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/LambdaFMModelTestCase.java
@@ -0,0 +1,75 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class LambdaFMModelTestCase {
+
+    @Test
+    public void testRecommenderByDynamic() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/game.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/lambdafmd-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(LambdaFMDynamicModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8738025F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.27287653F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.43647555F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.34705582F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(13.505785F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.13822167F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.35131538F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+    @Test
+    public void testRecommenderByStatic() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/game.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/lambdafms-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(LambdaFMStaticModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.87063825F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.27293852F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.43640044F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.34793553F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(16.4733F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.13940796F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.3569557F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+    @Test
+    public void testRecommenderByWeight() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/game.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/lambdafmw-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(LambdaFMWeightModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.87338704F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.27333382F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.4372049F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.34727877F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(14.714127F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.13741651F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.35251862F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ListwiseMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ListwiseMFModelTestCase.java
new file mode 100644
index 0000000..0cdb0d6
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/ListwiseMFModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class ListwiseMFModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/listwisemf-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(ListwiseMFModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.9082031F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.40511125F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.56619364F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5052081F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(15.5366535F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.32944426F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.6009242F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/PLSAModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/PLSAModelTestCase.java
new file mode 100644
index 0000000..e70f790
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/PLSAModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class PLSAModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/plsa-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(PLSAModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8995006F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.41217065F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5718696F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5059695F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(16.010801F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.32400653F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.5855675F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankALSModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankALSModelTestCase.java
new file mode 100644
index 0000000..35f3112
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankALSModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class RankALSModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/rankals-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(RankALSModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8590121F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.2925542F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.51014286F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.38870648F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(25.271967F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.22931133F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.425085F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankCDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankCDModelTestCase.java
new file mode 100644
index 0000000..c82f6a6
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankCDModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class RankCDModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/product.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/rankcd-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(RankCDModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.5614361F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.012605725F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.046074353F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.026806315F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(55.43044F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.015423674F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.035282508F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankSGDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankSGDModelTestCase.java
new file mode 100644
index 0000000..092bf2e
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankSGDModelTestCase.java
@@ -0,0 +1,42 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class RankSGDModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/ranksgd-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(RankSGDModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.8038758F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.23586644F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.42290422F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.3208106F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(42.833046F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.19363467F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.35374463F, measures.getFloat(RecallEvaluator.class), 0F);
+
+    }
+
+}
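All of these assertions pin exact float values with a delta of 0F, so any numerical drift fails the test. When chasing such a failure it can help to dump the whole measure map before re-pinning values; a minimal sketch using fastutil's primitive-specialized entry set (variable names match the tests above, and the snippet is illustrative rather than part of this diff):

    // Inside a test, after executing the task: print every measure and its score.
    Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
    for (it.unimi.dsi.fastutil.objects.Object2FloatMap.Entry<Class<? extends Evaluator>> term : measures.object2FloatEntrySet()) {
        System.out.println(term.getKey().getSimpleName() + " = " + term.getFloatValue());
    }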
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankVFCDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankVFCDModelTestCase.java
new file mode 100644
index 0000000..e00bd87
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/RankVFCDModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class RankVFCDModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/product.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/rankvfcd-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(RankVFCDModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.5782429F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.019607447F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.06451471F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.038099557F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(61.21012F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.01949143F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.048476923F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/SLIMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/SLIMModelTestCase.java
new file mode 100644
index 0000000..ff1fb7f
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/SLIMModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class SLIMModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/slim-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(SLIMModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.91848505F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.44851264F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.6108289F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5455681F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(16.679905F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.34018686F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.6302096F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/UserKNNRankingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/UserKNNRankingModelTestCase.java
new file mode 100644
index 0000000..292ee2c
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/UserKNNRankingModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class UserKNNRankingModelTestCase {
+
+    @Test
+    public void testRecommenderRanking() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/userknnranking-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(UserKNNRankingModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.90752447F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.41615525F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.57524806F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.51393044F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(12.909212F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.32891354F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.601523F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/VBPRModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/VBPRModelTestCase.java
new file mode 100644
index 0000000..049b3a0
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/VBPRModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class VBPRModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/product.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/vbpr-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(VBPRModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.54307634F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.009206047F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.035549607F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.018847404F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(45.42393F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.010338988F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.022509828F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WARPMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WARPMFModelTestCase.java
new file mode 100644
index 0000000..5852e22
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WARPMFModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class WARPMFModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/warpmf-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(WARPMFModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.88621616F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.3896325F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.54730076F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.47975117F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(20.227654F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.31890017F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.5490907F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WBPRModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WBPRModelTestCase.java
new file mode 100644
index 0000000..0aacdca
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WBPRModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class WBPRModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/wbpr-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(WBPRModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.78071666F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.24647275F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.3337322F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.30442417F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(17.18609F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.25000066F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.35516423F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WRMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WRMFModelTestCase.java
new file mode 100644
index 0000000..eb63056
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/ranking/WRMFModelTestCase.java
@@ -0,0 +1,41 @@
+package com.jstarcraft.rns.model.collaborative.ranking;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.MRREvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator;
+import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RankingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class WRMFModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/wrmf-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RankingTask job = new RankingTask(WRMFModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.90615845F, measures.getFloat(AUCEvaluator.class), 0F);
+        Assert.assertEquals(0.43277714F, measures.getFloat(MAPEvaluator.class), 0F);
+        Assert.assertEquals(0.5828448F, measures.getFloat(MRREvaluator.class), 0F);
+        Assert.assertEquals(0.5247985F, measures.getFloat(NDCGEvaluator.class), 0F);
+        Assert.assertEquals(15.179557F, measures.getFloat(NoveltyEvaluator.class), 0F);
+        Assert.assertEquals(0.32917875F, measures.getFloat(PrecisionEvaluator.class), 0F);
+        Assert.assertEquals(0.60779583F, measures.getFloat(RecallEvaluator.class), 0F);
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/ASVDPlusPlusModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/ASVDPlusPlusModelTestCase.java
new file mode 100644
index 0000000..4490a3d
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/ASVDPlusPlusModelTestCase.java
@@ -0,0 +1,33 @@
+package com.jstarcraft.rns.model.collaborative.rating;
+
+import java.util.Properties;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.jstarcraft.ai.evaluate.Evaluator;
+import com.jstarcraft.ai.evaluate.rating.MAEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MPEEvaluator;
+import com.jstarcraft.ai.evaluate.rating.MSEEvaluator;
+import com.jstarcraft.core.common.option.MapOption;
+import com.jstarcraft.core.common.option.Option;
+import com.jstarcraft.rns.task.RatingTask;
+
+import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap;
+
+public class ASVDPlusPlusModelTestCase {
+
+    @Test
+    public void testRecommender() throws Exception {
+        Properties keyValues = new Properties();
+        keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties"));
+        keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/asvdpp-test.properties"));
+        Option configuration = new MapOption(keyValues);
+        RatingTask job = new RatingTask(ASVDPlusPlusModel.class, configuration);
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
+        Assert.assertEquals(0.71975094F, measures.getFloat(MAEEvaluator.class), 0F);
+        Assert.assertEquals(0.77920896F, measures.getFloat(MPEEvaluator.class), 0F);
+        Assert.assertEquals(0.8519637F, measures.getFloat(MSEEvaluator.class), 0F);
+    }
+
+}
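From this point the change set switches from ranking to rating prediction: the test cases keep the same structure but run a RatingTask and check error metrics (MAE, MPE, MSE) instead of the seven ranking metrics. The core difference, restated from the test above as a sketch:

    // Rating-prediction variant of the same pattern (cf. ASVDPlusPlusModelTestCase).
    RatingTask job = new RatingTask(ASVDPlusPlusModel.class, configuration);
    Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute();
    float mae = measures.getFloat(MAEEvaluator.class); // mean absolute error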
mode 100644 index 0000000..2dc5889 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/AspectModelRatingModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class AspectModelRatingModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/aspectmodelrating-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(AspectModelRatingModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.65754443F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.97918296F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.71809036F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecModelTestCase.java new file mode 100644 index 0000000..afba872 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/AutoRecModelTestCase.java @@ -0,0 +1,34 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.neuralnetwork.AutoRecModel; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class AutoRecModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/autorec-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(AutoRecModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.6861356F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.97801197F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.83574665F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BHFreeRatingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BHFreeRatingModelTestCase.java new file mode 100644 index 0000000..a489b6f --- /dev/null +++ 
b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BHFreeRatingModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class BHFreeRatingModelTestCase { + + @Test + public void testRecommenderRating() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/bhfreerating-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(BHFreeRatingModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.7197399F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.77907884F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.85197836F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BPMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BPMFModelTestCase.java new file mode 100644 index 0000000..91b3e17 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BPMFModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class BPMFModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/bpmf-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(BPMFModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.665037F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9846474F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.70209664F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BUCMRatingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BUCMRatingModelTestCase.java new file mode 100644 index 0000000..c141b64 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BUCMRatingModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; +
+import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class BUCMRatingModelTestCase { + + @Test + public void testRecommenderRating() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/bucmrating-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(BUCMRatingModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.64833564F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.99102265F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6799151F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BiasedMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BiasedMFModelTestCase.java new file mode 100644 index 0000000..11e1f0e --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/BiasedMFModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class BiasedMFModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/biasedmf-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(BiasedMFModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6315707F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.98386675F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.66219884F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/CCDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/CCDModelTestCase.java new file mode 100644 index 0000000..8bb1dec --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/CCDModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import
com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class CCDModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/product.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/ccd-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(CCDModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.9645193F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9366471F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(1.6167855F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/CollaborativeRatingTestSuite.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/CollaborativeRatingTestSuite.java new file mode 100644 index 0000000..c0983ae --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/CollaborativeRatingTestSuite.java @@ -0,0 +1,58 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +@RunWith(Suite.class) +@SuiteClasses({ + // Recommender test cases + // recommender.cf.rating + AspectModelRatingModelTestCase.class, + + ASVDPlusPlusModelTestCase.class, + + BiasedMFModelTestCase.class, + + BHFreeRatingModelTestCase.class, + + BPMFModelTestCase.class, + + BUCMRatingModelTestCase.class, + + CCDModelTestCase.class, + + FFMModelTestCase.class, + + FMALSModelTestCase.class, + + FMSGDModelTestCase.class, + + GPLSAModelTestCase.class, + + IRRGModelTestCase.class, + + ItemKNNRatingModelTestCase.class, + + LDCCModelTestCase.class, + + LLORMAModelTestCase.class, + + MFALSModelTestCase.class, + + NMFModelTestCase.class, + + PMFModelTestCase.class, + + RBMModelTestCase.class, + + RFRecModelTestCase.class, + + SVDPlusPlusModelTestCase.class, + + URPModelTestCase.class, + + UserKNNRatingModelTestCase.class, }) +public class CollaborativeRatingTestSuite { + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FFMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FFMModelTestCase.java new file mode 100644 index 0000000..2dab9e6 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FFMModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class FFMModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); +
keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/ffm-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(FFMModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.63446355F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.984127F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6668231F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FMALSModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FMALSModelTestCase.java new file mode 100644 index 0000000..acb2d5c --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FMALSModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class FMALSModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/fmals-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(FMALSModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6478765F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.96031743F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.736361F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FMSGDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FMSGDModelTestCase.java new file mode 100644 index 0000000..7d0326f --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/FMSGDModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class FMSGDModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/fmsgd-test.properties")); + Option configuration = new
MapOption(keyValues); + RatingTask job = new RatingTask(FMSGDModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6345172F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9842571F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6671004F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/GPLSAModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/GPLSAModelTestCase.java new file mode 100644 index 0000000..c65fc65 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/GPLSAModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class GPLSAModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/gplsa-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(GPLSAModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.67359F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.98907F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.79879F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/IRRGModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/IRRGModelTestCase.java new file mode 100644 index 0000000..8aab62a --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/IRRGModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class IRRGModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/irrg-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(IRRGModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6476635F, measures.getFloat(MAEEvaluator.class), 0F); +
Assert.assertEquals(0.98776996F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.7369969F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/ItemKNNRatingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/ItemKNNRatingModelTestCase.java new file mode 100644 index 0000000..a0c79ab --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/ItemKNNRatingModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class ItemKNNRatingModelTestCase { + + @Test + public void testRecommenderRating() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/itemknnrating-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(ItemKNNRatingModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6234117F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.95394224F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6731172F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/LDCCModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/LDCCModelTestCase.java new file mode 100644 index 0000000..bc86928 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/LDCCModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class LDCCModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/ldcc-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(LDCCModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6638286F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9928441F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.7066555F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git
a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/LLORMAModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/LLORMAModelTestCase.java new file mode 100644 index 0000000..b90a885 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/LLORMAModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class LLORMAModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/llorma-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(LLORMAModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6493022F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.96591204F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.76067156F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/MFALSModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/MFALSModelTestCase.java new file mode 100644 index 0000000..0000a89 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/MFALSModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class MFALSModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/mfals-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(MFALSModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.82939005F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9454853F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(1.3054749F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/NMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/NMFModelTestCase.java new file mode 100644 index 0000000..6e89440 ---
/dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/NMFModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class NMFModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/nmf-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(NMFModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6766053F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.96604216F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.83493185F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/PMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/PMFModelTestCase.java new file mode 100644 index 0000000..1305fb8 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/PMFModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class PMFModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/pmf-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(PMFModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.729588F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.98165494F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.9994839F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/RBMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/RBMModelTestCase.java new file mode 100644 index 0000000..96ccc8e --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/RBMModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test;
+ +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class RBMModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/rbm-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(RBMModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.74484473F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.98503774F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.8896831F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/RFRecModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/RFRecModelTestCase.java new file mode 100644 index 0000000..1e44380 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/RFRecModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class RFRecModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/rfrec-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(RFRecModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.64007515F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9711163F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6939012F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/SVDPlusPlusModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/SVDPlusPlusModelTestCase.java new file mode 100644 index 0000000..01bb25b --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/SVDPlusPlusModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator;
+import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class SVDPlusPlusModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/svdpp-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(SVDPlusPlusModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6524793F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.99141294F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6828872F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/URPModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/URPModelTestCase.java new file mode 100644 index 0000000..b15988f --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/URPModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class URPModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/urp-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(URPModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.64206606F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9912829F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6712189F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/collaborative/rating/UserKNNRatingModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/UserKNNRatingModelTestCase.java new file mode 100644 index 0000000..7875786 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/collaborative/rating/UserKNNRatingModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.collaborative.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import
it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class UserKNNRatingModelTestCase { + + @Test + public void testRecommenderRating() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/userknnrating-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(UserKNNRatingModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.63933104F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.94639605F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6927988F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/content/ContentTestSuite.java b/src/test/java/com/jstarcraft/rns/model/content/ContentTestSuite.java new file mode 100644 index 0000000..2d8220e --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/content/ContentTestSuite.java @@ -0,0 +1,26 @@ +package com.jstarcraft.rns.model.content; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +import com.jstarcraft.rns.model.content.ranking.TFIDFModelTestCase; +import com.jstarcraft.rns.model.content.rating.HFTModelTestCase; +import com.jstarcraft.rns.model.content.rating.TopicMFATModelTestCase; +import com.jstarcraft.rns.model.content.rating.TopicMFMTModelTestCase; + +@RunWith(Suite.class) +@SuiteClasses({ + // Recommender test cases + EFMModelTestCase.class, + + HFTModelTestCase.class, + + TFIDFModelTestCase.class, + + TopicMFATModelTestCase.class, + + TopicMFMTModelTestCase.class }) +public class ContentTestSuite { + +} diff --git a/src/test/java/com/jstarcraft/rns/model/content/EFMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/content/EFMModelTestCase.java new file mode 100644 index 0000000..a87e27b --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/content/EFMModelTestCase.java @@ -0,0 +1,60 @@ +package com.jstarcraft.rns.model.content; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.model.content.ranking.EFMRankingModel; +import com.jstarcraft.rns.model.content.rating.EFMRatingModel; +import com.jstarcraft.rns.task.RankingTask; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class EFMModelTestCase { + + @Test + public void testRecommenderByRanking() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/dc_dense.properties")); +
keyValues.load(this.getClass().getResourceAsStream("/model/content/efmranking-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(EFMRankingModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6127146F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.01611203F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.04630792F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.040448334F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(53.2614F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.023869349F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.073571086F, measures.getFloat(RecallEvaluator.class), 0F); + } + + @Test + public void testRecommenderByRating() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/dc_dense.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/content/efmrating-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(EFMRatingModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6154602F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.8536428F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.78278536F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/content/ranking/TFIDFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/content/ranking/TFIDFModelTestCase.java new file mode 100644 index 0000000..0a6f220 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/content/ranking/TFIDFModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.content.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class TFIDFModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/product.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/content/tfidf-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(TFIDFModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.526974F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.0025012426F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.009865027F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.0074107912F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(66.74913F,
measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.006129954F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.012177796F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/content/rating/HFTModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/content/rating/HFTModelTestCase.java new file mode 100644 index 0000000..415edf2 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/content/rating/HFTModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.content.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class HFTModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/musical_instruments.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/content/hft-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(HFTModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.64272016F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.94885534F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.8139318F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/content/rating/TopicMFATModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/content/rating/TopicMFATModelTestCase.java new file mode 100644 index 0000000..98c0527 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/content/rating/TopicMFATModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.content.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class TopicMFATModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/musical_instruments.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/content/topicmfat-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(TopicMFATModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.61896443F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9873356F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.7254535F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git
a/src/test/java/com/jstarcraft/rns/model/content/rating/TopicMFMTModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/content/rating/TopicMFMTModelTestCase.java new file mode 100644 index 0000000..3e1d9b4 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/content/rating/TopicMFMTModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.content.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class TopicMFMTModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/musical_instruments.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/content/topicmfmt-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(TopicMFMTModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.61896443F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9873356F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.7254535F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/ContextTestSuite.java b/src/test/java/com/jstarcraft/rns/model/context/ContextTestSuite.java new file mode 100644 index 0000000..57981a8 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/ContextTestSuite.java @@ -0,0 +1,42 @@ +package com.jstarcraft.rns.model.context; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +import com.jstarcraft.rns.model.context.ranking.RankGeoFMModelTestCase; +import com.jstarcraft.rns.model.context.ranking.SBPRModelTestCase; +import com.jstarcraft.rns.model.context.rating.RSTEModelTestCase; +import com.jstarcraft.rns.model.context.rating.SoRecModelTestCase; +import com.jstarcraft.rns.model.context.rating.SoRegModelTestCase; +import com.jstarcraft.rns.model.context.rating.SocialMFModelTestCase; +import com.jstarcraft.rns.model.context.rating.TimeSVDModelTestCase; +import com.jstarcraft.rns.model.context.rating.TrustMFModelTestCase; +import com.jstarcraft.rns.model.context.rating.TrustSVDModelTestCase; + +@RunWith(Suite.class) +@SuiteClasses({ + // Recommender test cases + + // recommender.context.ranking + RankGeoFMModelTestCase.class, + + SBPRModelTestCase.class, + + // recommender.context.rating + RSTEModelTestCase.class, + + SocialMFModelTestCase.class, + + SoRecModelTestCase.class, + + SoRegModelTestCase.class, + + TimeSVDModelTestCase.class, + + TrustMFModelTestCase.class, + + TrustSVDModelTestCase.class }) +public class ContextTestSuite { + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/ranking/RankGeoFMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/ranking/RankGeoFMModelTestCase.java new file mode 100644 index 0000000..66cf4fa --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/ranking/RankGeoFMModelTestCase.java @@ -0,0 +1,41 @@ +package
com.jstarcraft.rns.model.context.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class RankGeoFMModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/Foursquare.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/ranking/rankgeofm-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(RankGeoFMModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.7270785F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.054851912F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.2401193F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.110572465F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(37.500404F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.07865529F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.08640095F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/ranking/SBPRModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/ranking/SBPRModelTestCase.java new file mode 100644 index 0000000..6283b8f --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/ranking/SBPRModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.context.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class SBPRModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/ranking/sbpr-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(SBPRModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); +
Assert.assertEquals(0.91010016F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.41188803F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.56480426F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.5072589F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(15.679053F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.32440445F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.59699106F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/rating/RSTEModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/rating/RSTEModelTestCase.java new file mode 100644 index 0000000..04644bc --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/rating/RSTEModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class RSTEModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/rating/rste-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(RSTEModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.64303476F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.99206346F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6777738F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/rating/SoRecModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/rating/SoRecModelTestCase.java new file mode 100644 index 0000000..30f181e --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/rating/SoRecModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class SoRecModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/rating/sorec-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(SoRecModel.class, configuration); +
Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.64304614F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9923237F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6777599F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/rating/SoRegModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/rating/SoRegModelTestCase.java new file mode 100644 index 0000000..3f62782 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/rating/SoRegModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class SoRegModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/rating/soreg-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(SoRegModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6594284F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.9673432F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.72760236F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/rating/SocialMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/rating/SocialMFModelTestCase.java new file mode 100644 index 0000000..7ba43c5 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/rating/SocialMFModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class SocialMFModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/rating/socialmf-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(SocialMFModel.class, configuration); + Object2FloatSortedMap<Class<? extends Evaluator>> measures = job.execute(); + Assert.assertEquals(0.6466795F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.98881084F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6822788F,
measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/rating/TimeSVDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/rating/TimeSVDModelTestCase.java new file mode 100644 index 0000000..532f860 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/rating/TimeSVDModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class TimeSVDModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/rating/timesvd-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(TimeSVDModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.6895415F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.93325526F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.8778289F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/rating/TrustMFModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/rating/TrustMFModelTestCase.java new file mode 100644 index 0000000..87fe6c6 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/rating/TrustMFModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class TrustMFModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/rating/trustmf-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(TrustMFModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.6378669F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.98985165F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.69016904F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/context/rating/TrustSVDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/context/rating/TrustSVDModelTestCase.java new file mode 100644 index 
0000000..b29ef85 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/context/rating/TrustSVDModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.context.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class TrustSVDModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/context/rating/trustsvd-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(TrustSVDModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.61983573F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.98933125F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.6387536F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEModelTestCase.java new file mode 100644 index 0000000..2f9a547 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/dl4j/ranking/CDAEModelTestCase.java @@ -0,0 +1,43 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +// TODO 存档,以后需要基于DL4J重构. 
+@Deprecated +public class CDAEModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/cdae-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(CDAEModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.9188042F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.40759084F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.5685547F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.5108937F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(11.824657F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.3305053F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.61967427F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMModelTestCase.java new file mode 100644 index 0000000..6c3ad7a --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/dl4j/ranking/DeepFMModelTestCase.java @@ -0,0 +1,43 @@ +package com.jstarcraft.rns.model.dl4j.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +//TODO 存档,以后需要基于DL4J重构. 
+@Deprecated +public class DeepFMModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/ranking/deepfm-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(DeepFMModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.916794F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.4057996F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.5699482F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.509845F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(11.902421F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.3271896F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.6142564F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecModelTestCase.java new file mode 100644 index 0000000..a5ea16a --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/dl4j/rating/AutoRecModelTestCase.java @@ -0,0 +1,35 @@ +package com.jstarcraft.rns.model.dl4j.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +//TODO 存档,以后需要基于DL4J重构. 
+@Deprecated +public class AutoRecModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/collaborative/rating/autorec-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(AutoRecModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.6861356F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.97801197F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.83574665F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/extend/ExtendTestSuite.java b/src/test/java/com/jstarcraft/rns/model/extend/ExtendTestSuite.java new file mode 100644 index 0000000..02bdc39 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/extend/ExtendTestSuite.java @@ -0,0 +1,24 @@ +package com.jstarcraft.rns.model.extend; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +import com.jstarcraft.rns.model.extend.ranking.AssociationRuleModelTestCase; +import com.jstarcraft.rns.model.extend.ranking.PRankDModelTestCase; +import com.jstarcraft.rns.model.extend.rating.PersonalityDiagnosisModelTestCase; +import com.jstarcraft.rns.model.extend.rating.SlopeOneModelTestCase; + +@RunWith(Suite.class) +@SuiteClasses({ + // 推荐器测试集 + AssociationRuleModelTestCase.class, + + PersonalityDiagnosisModelTestCase.class, + + PRankDModelTestCase.class, + + SlopeOneModelTestCase.class }) +public class ExtendTestSuite { + +} diff --git a/src/test/java/com/jstarcraft/rns/model/extend/ranking/AssociationRuleModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/extend/ranking/AssociationRuleModelTestCase.java new file mode 100644 index 0000000..a88f0f2 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/extend/ranking/AssociationRuleModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.extend.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class AssociationRuleModelTestCase { + + @Test + public void testAssociationRuleRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/extend/associationrule-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(AssociationRuleModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.90853435F, measures.getFloat(AUCEvaluator.class), 0F); + 
Assert.assertEquals(0.4180115F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.57776606F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.5162147F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(12.65794F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.33262724F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.6070039F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/extend/ranking/PRankDModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/extend/ranking/PRankDModelTestCase.java new file mode 100644 index 0000000..10db202 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/extend/ranking/PRankDModelTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.model.extend.ranking; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RankingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class PRankDModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/extend/prankd-test.properties")); + Option configuration = new MapOption(keyValues); + RankingTask job = new RankingTask(PRankDModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.7447238F, measures.getFloat(AUCEvaluator.class), 0F); + Assert.assertEquals(0.22893752F, measures.getFloat(MAPEvaluator.class), 0F); + Assert.assertEquals(0.32406133F, measures.getFloat(MRREvaluator.class), 0F); + Assert.assertEquals(0.283898F, measures.getFloat(NDCGEvaluator.class), 0F); + Assert.assertEquals(45.81069F, measures.getFloat(NoveltyEvaluator.class), 0F); + Assert.assertEquals(0.19436367F, measures.getFloat(PrecisionEvaluator.class), 0F); + Assert.assertEquals(0.32903677F, measures.getFloat(RecallEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/extend/rating/PersonalityDiagnosisModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/extend/rating/PersonalityDiagnosisModelTestCase.java new file mode 100644 index 0000000..3ce2bdd --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/extend/rating/PersonalityDiagnosisModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.extend.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import 
com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class PersonalityDiagnosisModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/extend/personalitydiagnosis-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(PersonalityDiagnosisModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.7296383F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.7661983F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(1.0307052F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/model/extend/rating/SlopeOneModelTestCase.java b/src/test/java/com/jstarcraft/rns/model/extend/rating/SlopeOneModelTestCase.java new file mode 100644 index 0000000..61023f4 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/model/extend/rating/SlopeOneModelTestCase.java @@ -0,0 +1,33 @@ +package com.jstarcraft.rns.model.extend.rating; + +import java.util.Properties; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.rns.task.RatingTask; + +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +public class SlopeOneModelTestCase { + + @Test + public void testRecommender() throws Exception { + Properties keyValues = new Properties(); + keyValues.load(this.getClass().getResourceAsStream("/data/filmtrust.properties")); + keyValues.load(this.getClass().getResourceAsStream("/model/extend/slopeone-test.properties")); + Option configuration = new MapOption(keyValues); + RatingTask job = new RatingTask(SlopeOneModel.class, configuration); + Object2FloatSortedMap> measures = job.execute(); + Assert.assertEquals(0.6378848F, measures.getFloat(MAEEvaluator.class), 0F); + Assert.assertEquals(0.96174866F, measures.getFloat(MPEEvaluator.class), 0F); + Assert.assertEquals(0.7105687F, measures.getFloat(MSEEvaluator.class), 0F); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/script/ScriptTestCase.java b/src/test/java/com/jstarcraft/rns/script/ScriptTestCase.java new file mode 100644 index 0000000..a7aea9f --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/script/ScriptTestCase.java @@ -0,0 +1,254 @@ +package com.jstarcraft.rns.script; + +import java.io.File; +import java.util.Map; +import java.util.Properties; + +import org.apache.commons.io.FileUtils; +import org.junit.Assert; +import org.junit.Test; +import org.luaj.vm2.LuaTable; + +import com.jstarcraft.core.common.option.MapOption; +import com.jstarcraft.core.script.ScriptContext; +import com.jstarcraft.core.script.ScriptExpression; +import com.jstarcraft.core.script.ScriptScope; +import com.jstarcraft.core.script.groovy.GroovyExpression; +import com.jstarcraft.core.script.js.JsExpression; +import com.jstarcraft.core.script.kotlin.KotlinExpression; +import com.jstarcraft.core.script.lua.LuaExpression; +import com.jstarcraft.core.script.python.PythonExpression; +import 
com.jstarcraft.core.script.ruby.RubyExpression; +import com.jstarcraft.core.utility.StringUtility; + +public class ScriptTestCase { + + private static final ClassLoader loader = ScriptTestCase.class.getClassLoader(); + + /** + * 使用BeanShell脚本与JStarCraft框架交互 + * + * @throws Exception + */ + @Test + public void testBeanShell() throws Exception { + // 获取BeanShell脚本 + File file = new File(ScriptTestCase.class.getResource("Model.bsh").toURI()); + String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + + // 设置BeanShell脚本使用到的Java类 + ScriptContext context = new ScriptContext(); + context.useClasses(Properties.class, Assert.class); + context.useClass("Option", MapOption.class); + context.useClasses("com.jstarcraft.ai.evaluate"); + context.useClasses("com.jstarcraft.rns.task"); + context.useClasses("com.jstarcraft.rns.model.benchmark"); + // 设置BeanShell脚本使用到的Java变量 + ScriptScope scope = new ScriptScope(); + scope.createAttribute("loader", loader); + + // 执行BeanShell脚本 + ScriptExpression expression = new GroovyExpression(context, scope, script); + Map data = expression.doWith(Map.class); + Assert.assertEquals(0.005825241F, data.get("precision"), 0F); + Assert.assertEquals(0.011579763F, data.get("recall"), 0F); + Assert.assertEquals(1.2708743F, data.get("mae"), 0F); + Assert.assertEquals(2.425075F, data.get("mse"), 0F); + } + + /** + * 使用Groovy脚本与JStarCraft框架交互 + * + * @throws Exception + */ + @Test + public void testGroovy() throws Exception { + // 获取Groovy脚本 + File file = new File(ScriptTestCase.class.getResource("Model.groovy").toURI()); + String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + + // 设置Groovy脚本使用到的Java类 + ScriptContext context = new ScriptContext(); + context.useClasses(Properties.class, Assert.class); + context.useClass("Option", MapOption.class); + context.useClasses("com.jstarcraft.ai.evaluate"); + context.useClasses("com.jstarcraft.rns.task"); + context.useClasses("com.jstarcraft.rns.model.benchmark"); + // 设置Groovy脚本使用到的Java变量 + ScriptScope scope = new ScriptScope(); + scope.createAttribute("loader", loader); + + // 执行Groovy脚本 + ScriptExpression expression = new GroovyExpression(context, scope, script); + Map data = expression.doWith(Map.class); + Assert.assertEquals(0.005825241F, data.get("precision"), 0F); + Assert.assertEquals(0.011579763F, data.get("recall"), 0F); + Assert.assertEquals(1.2708743F, data.get("mae"), 0F); + Assert.assertEquals(2.425075F, data.get("mse"), 0F); + } + + /** + * 使用JS脚本与JStarCraft框架交互 + * + * @throws Exception + */ + @Test + public void testJs() throws Exception { + // 获取JS脚本 + File file = new File(ScriptTestCase.class.getResource("Model.js").toURI()); + String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + + // 设置JS脚本使用到的Java类 + ScriptContext context = new ScriptContext(); + context.useClasses(Properties.class, Assert.class); + context.useClass("Option", MapOption.class); + context.useClasses("com.jstarcraft.ai.evaluate"); + context.useClasses("com.jstarcraft.rns.task"); + context.useClasses("com.jstarcraft.rns.model.benchmark"); + // 设置JS脚本使用到的Java变量 + ScriptScope scope = new ScriptScope(); + scope.createAttribute("loader", loader); + + // 执行JS脚本 + ScriptExpression expression = new JsExpression(context, scope, script); + Map data = expression.doWith(Map.class); + Assert.assertEquals(0.005825241096317768F, data.get("precision"), 0F); + Assert.assertEquals(0.011579763144254684F, data.get("recall"), 0F); + Assert.assertEquals(1.270874261856079F, data.get("mae"), 0F); + 
Assert.assertEquals(2.425075054168701F, data.get("mse"), 0F); + } + + /** + * 使用Kotlin脚本与JStarCraft框架交互 + * + * @throws Exception + */ + @Test + public void testKotlin() throws Exception { + // 获取Kotlin脚本 + File file = new File(ScriptTestCase.class.getResource("Model.kt").toURI()); + String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + + // 设置Kotlin脚本使用到的Java类 + ScriptContext context = new ScriptContext(); + context.useClasses(Properties.class, Assert.class); + context.useClass("Option", MapOption.class); + context.useClasses("com.jstarcraft.ai.evaluate"); + context.useClasses("com.jstarcraft.rns.task"); + context.useClasses("com.jstarcraft.rns.model.benchmark"); + // 设置Kotlin脚本使用到的Java变量 + ScriptScope scope = new ScriptScope(); + scope.createAttribute("loader", loader); + + // 执行Kotlin脚本 + ScriptExpression expression = new KotlinExpression(context, scope, script); + Map data = expression.doWith(Map.class); + Assert.assertEquals(0.005825241096317768F, data.get("precision"), 0F); + Assert.assertEquals(0.011579763144254684F, data.get("recall"), 0F); + Assert.assertEquals(1.270874261856079F, data.get("mae"), 0F); + Assert.assertEquals(2.425075054168701F, data.get("mse"), 0F); + } + + /** + * 使用Lua脚本与JStarCraft框架交互 + * + *
+     * Running the unit tests on Java 11 throws an "Unable to make {member} accessible: module {A} does not '{operation} {package}' to {B}" exception.
+     * This is caused by the Java 9 module system.
+     * It requires the JVM argument: --add-exports java.base/jdk.internal.loader=ALL-UNNAMED
+     * 
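+     * For example, with Maven Surefire the flag can be supplied through the standard argLine property
+     * (a sketch; adjust to the actual build setup):
+     * mvn test -DargLine="--add-exports java.base/jdk.internal.loader=ALL-UNNAMED"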
+ * + * @throws Exception + */ + @Test + public void testLua() throws Exception { + // 获取Lua脚本 + File file = new File(ScriptTestCase.class.getResource("Model.lua").toURI()); + String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + + // 设置Lua脚本使用到的Java类 + ScriptContext context = new ScriptContext(); + context.useClasses(Properties.class, Assert.class); + context.useClass("Option", MapOption.class); + context.useClasses("com.jstarcraft.ai.evaluate"); + context.useClasses("com.jstarcraft.rns.task"); + context.useClasses("com.jstarcraft.rns.model.benchmark"); + // 设置Lua脚本使用到的Java变量 + ScriptScope scope = new ScriptScope(); + scope.createAttribute("loader", loader); + + // 执行Lua脚本 + ScriptExpression expression = new LuaExpression(context, scope, script); + LuaTable data = expression.doWith(LuaTable.class); + Assert.assertEquals(0.005825241F, data.get("precision").tofloat(), 0F); + Assert.assertEquals(0.011579763F, data.get("recall").tofloat(), 0F); + Assert.assertEquals(1.2708743F, data.get("mae").tofloat(), 0F); + Assert.assertEquals(2.425075F, data.get("mse").tofloat(), 0F); + } + + /** + * 使用Python脚本与JStarCraft框架交互 + * + * @throws Exception + */ + @Test + public void testPython() throws Exception { + // 设置Python环境变量 + System.setProperty("python.console.encoding", StringUtility.CHARSET.name()); + + // 获取Python脚本 + File file = new File(ScriptTestCase.class.getResource("Model.py").toURI()); + String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + + // 设置Python脚本使用到的Java类 + ScriptContext context = new ScriptContext(); + context.useClasses(Properties.class, Assert.class); + context.useClass("Option", MapOption.class); + context.useClasses("com.jstarcraft.ai.evaluate"); + context.useClasses("com.jstarcraft.rns.task"); + context.useClasses("com.jstarcraft.rns.model.benchmark"); + // 设置Python脚本使用到的Java变量 + ScriptScope scope = new ScriptScope(); + scope.createAttribute("loader", loader); + + // 执行Python脚本 + ScriptExpression expression = new PythonExpression(context, scope, script); + Map data = expression.doWith(Map.class); + Assert.assertEquals(0.005825241096317768D, data.get("precision"), 0D); + Assert.assertEquals(0.011579763144254684D, data.get("recall"), 0D); + Assert.assertEquals(1.270874261856079D, data.get("mae"), 0D); + Assert.assertEquals(2.425075054168701D, data.get("mse"), 0D); + } + + /** + * 使用Ruby脚本与JStarCraft框架交互 + * + * @throws Exception + */ + @Test + public void testRuby() throws Exception { + // 获取Ruby脚本 + File file = new File(ScriptTestCase.class.getResource("Model.rb").toURI()); + String script = FileUtils.readFileToString(file, StringUtility.CHARSET); + + // 设置Ruby脚本使用到的Java类 + ScriptContext context = new ScriptContext(); + context.useClasses(Properties.class, Assert.class); + context.useClass("Option", MapOption.class); + context.useClasses("com.jstarcraft.ai.evaluate"); + context.useClasses("com.jstarcraft.rns.task"); + context.useClasses("com.jstarcraft.rns.model.benchmark"); + // 设置Ruby脚本使用到的Java变量 + ScriptScope scope = new ScriptScope(); + scope.createAttribute("loader", loader); + + // 执行Ruby脚本 + ScriptExpression expression = new RubyExpression(context, scope, script); + Map data = expression.doWith(Map.class); + Assert.assertEquals(0.005825241096317768D, data.get("precision"), 0D); + Assert.assertEquals(0.011579763144254684D, data.get("recall"), 0D); + Assert.assertEquals(1.270874261856079D, data.get("mae"), 0D); + Assert.assertEquals(2.425075054168701D, data.get("mse"), 0D); + } + +} diff --git 
a/src/test/java/com/jstarcraft/rns/task/AbstractTask.java b/src/test/java/com/jstarcraft/rns/task/AbstractTask.java new file mode 100644 index 0000000..c901d43 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/task/AbstractTask.java @@ -0,0 +1,334 @@ +package com.jstarcraft.rns.task; + +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.lang.reflect.Type; +import java.text.DecimalFormat; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import org.apache.commons.csv.CSVFormat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.DataModule; +import com.jstarcraft.ai.data.DataSpace; +import com.jstarcraft.ai.data.converter.ArffConverter; +import com.jstarcraft.ai.data.converter.CsvConverter; +import com.jstarcraft.ai.data.converter.DataConverter; +import com.jstarcraft.ai.data.module.ReferenceModule; +import com.jstarcraft.ai.data.processor.DataSplitter; +import com.jstarcraft.ai.environment.EnvironmentContext; +import com.jstarcraft.ai.environment.EnvironmentFactory; +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.core.common.conversion.json.JsonUtility; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.common.reflection.ReflectionUtility; +import com.jstarcraft.core.common.reflection.TypeUtility; +import com.jstarcraft.core.utility.Integer2FloatKeyValue; +import com.jstarcraft.core.utility.KeyValue; +import com.jstarcraft.core.utility.RandomUtility; +import com.jstarcraft.core.utility.StringUtility; +import com.jstarcraft.rns.data.processor.QualityFeatureDataSplitter; +import com.jstarcraft.rns.data.separator.DataSeparator; +import com.jstarcraft.rns.data.separator.GivenDataSeparator; +import com.jstarcraft.rns.data.separator.GivenNumberSeparator; +import com.jstarcraft.rns.data.separator.KFoldCrossValidationSeparator; +import com.jstarcraft.rns.data.separator.LeaveOneCrossValidationSeparator; +import com.jstarcraft.rns.data.separator.RandomSeparator; +import com.jstarcraft.rns.data.separator.RatioSeparator; +import com.jstarcraft.rns.model.Model; +import com.jstarcraft.rns.model.exception.ModelException; + +import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap; +import it.unimi.dsi.fastutil.objects.Object2FloatMap; +import it.unimi.dsi.fastutil.objects.Object2FloatRBTreeMap; +import it.unimi.dsi.fastutil.objects.Object2FloatSortedMap; + +/** + * 抽象任务 + * + * @author Birdy + * + * @param + */ +public abstract class AbstractTask { + + protected final Logger logger = LoggerFactory.getLogger(this.getClass()); + + protected Option configurator; + + protected String userField, itemField, scoreField; + + protected int userDimension, itemDimension, userSize, itemSize; + + protected ReferenceModule[] trainModules, testModules; + + protected DataModule dataModule, trainMarker, testMarker; + + protected Model model; + + protected AbstractTask(Model model, Option configurator) { + this.configurator = configurator; + Long seed = configurator.getLong("recommender.random.seed"); + if (seed != null) { + RandomUtility.setSeed(seed); + } + 
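+        // (The seed above, when configured, keeps the exact-value metric assertions in the test cases reproducible.)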
this.model = model;
+    }
+
+    protected AbstractTask(Class<? extends Model> clazz, Option configurator) {
+        this.configurator = configurator;
+        Long seed = configurator.getLong("recommender.random.seed");
+        if (seed != null) {
+            RandomUtility.setSeed(seed);
+        }
+        this.model = (Model) ReflectionUtility.getInstance(clazz);
+    }
+
+    protected abstract Collection<Evaluator> getEvaluators(SparseMatrix featureMatrix);
+
+    protected abstract L check(int userIndex);
+
+    protected abstract R recommend(Model recommender, int userIndex);
+
+    private ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
+
+    private Map<Class<? extends Evaluator>, Integer2FloatKeyValue> evaluate(Collection<Evaluator> evaluators, Model recommender) {
+        Map<Class<? extends Evaluator>, Integer2FloatKeyValue[]> values = new HashMap<>();
+        for (Evaluator evaluator : evaluators) {
+            values.put(evaluator.getClass(), new Integer2FloatKeyValue[userSize]);
+        }
+        // Split the work by user.
+        CountDownLatch latch = new CountDownLatch(userSize);
+        for (int userIndex = 0; userIndex < userSize; userIndex++) {
+            int index = userIndex;
+            executor.submit(() -> {
+                try {
+                    ReferenceModule module = testModules[index];
+                    if (module.getSize() == 0) {
+                        return;
+                    }
+                    // Ground-truth collection
+                    L checkCollection = check(index);
+                    // Recommendation list
+                    R recommendList = recommend(recommender, index);
+                    // Measure with every evaluator
+                    for (Evaluator evaluator : evaluators) {
+                        Integer2FloatKeyValue[] measures = values.get(evaluator.getClass());
+                        Integer2FloatKeyValue measure = evaluator.evaluate(checkCollection, recommendList);
+                        measures[index] = measure;
+                    }
+                } catch (Exception exception) {
+                    logger.error("Task error", exception);
+                } finally {
+                    latch.countDown();
+                }
+            });
+        }
+        try {
+            latch.await();
+        } catch (Exception exception) {
+            throw new ModelException(exception);
+        }
+
+        Map<Class<? extends Evaluator>, Integer2FloatKeyValue> measures = new HashMap<>();
+        for (Entry<Class<? extends Evaluator>, Integer2FloatKeyValue[]> term : values.entrySet()) {
+            Integer2FloatKeyValue measure = new Integer2FloatKeyValue(0, 0F);
+            for (Integer2FloatKeyValue element : term.getValue()) {
+                if (element == null) {
+                    continue;
+                }
+                measure.setKey(measure.getKey() + element.getKey());
+                measure.setValue(measure.getValue() + element.getValue());
+            }
+            measures.put(term.getKey(), measure);
+        }
+        return measures;
+    }
+
+    public Object2FloatSortedMap<Class<? extends Evaluator>> execute() throws Exception {
+        userField = configurator.getString("data.model.fields.user", "user");
+        itemField = configurator.getString("data.model.fields.item", "item");
+        scoreField = configurator.getString("data.model.fields.score", "score");
+
+        // TODO Data attribute section
+        // Discrete attributes (the configuration key keeps the original "dicrete" spelling used by the resource files)
+        Type discreteConfiguration = TypeUtility.parameterize(HashMap.class, String.class, Class.class);
+        Map<String, Class<?>> discreteDefinitions = JsonUtility.string2Object(configurator.getString("data.attributes.dicrete"), discreteConfiguration);
+        // Continuous attributes
+        Type continuousConfiguration = TypeUtility.parameterize(HashMap.class, String.class, Class.class);
+        Map<String, Class<?>> continuousDefinitions = JsonUtility.string2Object(configurator.getString("data.attributes.continuous"), continuousConfiguration);
+
+        // Data space section
+        DataSpace space = new DataSpace(discreteDefinitions, continuousDefinitions);
+
+        // TODO Data module section
+        ModuleConfigurer[] moduleConfigurers = JsonUtility.string2Object(configurator.getString("data.modules"), ModuleConfigurer[].class);
+        for (ModuleConfigurer moduleConfigurer : moduleConfigurers) {
+            space.makeDenseModule(moduleConfigurer.getName(), moduleConfigurer.getConfiguration(), 1000000000);
+        }
+
+        // TODO Data converter section
+        Type convertorConfiguration = TypeUtility.parameterize(LinkedHashMap.class, String.class, TypeUtility.parameterize(KeyValue.class, String.class, HashMap.class));
+        ConverterConfigurer[] converterConfigurers = JsonUtility.string2Object(configurator.getString("data.converters"), ConverterConfigurer[].class);
+        for (ConverterConfigurer converterConfigurer : converterConfigurers) {
+            String name = converterConfigurer.getName();
+            String type = converterConfigurer.getType();
+            String path = converterConfigurer.getPath();
+            DataConverter convertor = null;
+            switch (type) {
+            case "arff": {
+                convertor = ReflectionUtility.getInstance(ArffConverter.class, space.getQualityAttributes(), space.getQuantityAttributes());
+                break;
+            }
+            case "csv": {
+                CSVFormat format = CSVFormat.DEFAULT.withDelimiter(configurator.getCharacter("data.separator.delimiter", ' '));
+                convertor = ReflectionUtility.getInstance(CsvConverter.class, format, space.getQualityAttributes(), space.getQuantityAttributes());
+                break;
+            }
+            default: {
+                throw new ModelException("Unsupported converter format");
+            }
+            }
+            File file = new File(path);
+            DataModule module = space.getModule(name);
+            try (InputStream stream = new FileInputStream(file)) {
+                convertor.convert(module, stream);
+            }
+        }
+
+        // TODO Data separator section
+        SeparatorConfigurer separatorConfigurer = JsonUtility.string2Object(configurator.getString("data.separator"), SeparatorConfigurer.class);
+        DataModule module = space.getModule(separatorConfigurer.getName());
+        int scoreDimension = module.getQuantityInner(scoreField);
+        for (DataInstance instance : module) {
+            // Use the score feature as the quantity mark
+            instance.setQuantityMark(instance.getQuantityFeature(scoreDimension));
+        }
+        DataSeparator separator;
+        switch (separatorConfigurer.getType()) {
+        case "kcv": {
+            int size = configurator.getInteger("data.separator.kcv.number", 1);
+            separator = new KFoldCrossValidationSeparator(module, size);
+            break;
+        }
+        case "loocv": {
+            separator = new LeaveOneCrossValidationSeparator(space, module, separatorConfigurer.getMatchField(), separatorConfigurer.getSortField());
+            break;
+        }
+        case "testset": {
+            int threshold = configurator.getInteger("data.separator.threshold");
+            separator = new GivenDataSeparator(module, threshold);
+            break;
+        }
+        case "givenn": {
+            int number = configurator.getInteger("data.separator.given-number.number");
+            separator = new GivenNumberSeparator(space, module, separatorConfigurer.getMatchField(), separatorConfigurer.getSortField(), number);
+            break;
+        }
+        case "random": {
+            float random = configurator.getFloat("data.separator.random.value", 0.8F);
+            separator = new RandomSeparator(space, module, separatorConfigurer.getMatchField(), random);
+            break;
+        }
+        case "ratio": {
+            float ratio = configurator.getFloat("data.separator.ratio.value", 0.8F);
+            separator = new RatioSeparator(space, module, separatorConfigurer.getMatchField(), separatorConfigurer.getSortField(), ratio);
+            break;
+        }
+        default: {
+            throw new ModelException("Unsupported separator type");
+        }
+        }
+
+        // Evaluation section
+        Double binarize = configurator.getDouble("data.convert.binarize.threshold");
+        Object2FloatSortedMap<Class<? extends Evaluator>> measures = new Object2FloatRBTreeMap<>((left, right) -> {
+            return left.getName().compareTo(right.getName());
+        });
+
+        EnvironmentContext context = EnvironmentFactory.getContext();
+        StringBuffer message = new StringBuffer();
+        Future<?> task = context.doTask(() -> {
+            try {
+                for (int index = 0; index < separator.getSize(); index++) {
+                    trainMarker = separator.getTrainReference(index);
+                    testMarker = separator.getTestReference(index);
+                    dataModule = module;
+
+                    userDimension = module.getQualityInner(userField);
+                    itemDimension = module.getQualityInner(itemField);
+                    userSize = space.getQualityAttribute(userField).getSize();
+                    itemSize = space.getQualityAttribute(itemField).getSize();
+
+                    DataSplitter splitter = new QualityFeatureDataSplitter(userDimension);
+                    trainModules = splitter.split(trainMarker, userSize);
+                    testModules = splitter.split(testMarker, userSize);
+
+                    HashMatrix dataTable = new HashMatrix(true, userSize, itemSize, new Long2FloatRBTreeMap());
+                    for (DataInstance instance : dataModule) {
+                        int rowIndex = instance.getQualityFeature(userDimension);
+                        int columnIndex = instance.getQualityFeature(itemDimension);
+                        // TODO Handle conflicting entries
+                        dataTable.setValue(rowIndex, columnIndex, instance.getQuantityMark());
+                    }
+                    SparseMatrix featureMatrix = SparseMatrix.valueOf(userSize, itemSize, dataTable);
+                    message.append(StringUtility.format("| {} |", model.getClass().getSimpleName()));
+                    {
+                        long current = System.currentTimeMillis();
+                        model.prepare(configurator, trainMarker, space);
+                        model.practice();
+                        message.append(StringUtility.format(" {} |", System.currentTimeMillis() - current));
+                    }
+                    {
+                        long current = System.currentTimeMillis();
+                        for (Entry<Class<? extends Evaluator>, Integer2FloatKeyValue> measure : evaluate(getEvaluators(featureMatrix), model).entrySet()) {
+                            float value = measure.getValue().getValue() / measure.getValue().getKey();
+                            measures.put(measure.getKey(), value);
+                        }
+                        message.append(StringUtility.format(" {} |", System.currentTimeMillis() - current));
+                    }
+                }
+            } catch (Exception exception) {
+                logger.error("Task error", exception);
+            }
+        });
+        task.get();
+
+        for (Object2FloatMap.Entry<Class<? extends Evaluator>> term : measures.object2FloatEntrySet()) {
+            term.setValue(term.getFloatValue() / separator.getSize());
+            if (logger.isDebugEnabled()) {
+                logger.debug(StringUtility.format("Assert.assertEquals({}F, measures.getFloat({}.class), 0F);", term.getFloatValue(), term.getKey().getSimpleName()));
+            }
+            message.append(StringUtility.format(" {} |", format.format(term.getFloatValue())));
+        }
+
+        if (logger.isInfoEnabled()) {
+            logger.info(message.toString());
+        }
+        return measures;
+    }
+
+    private DecimalFormat format = new DecimalFormat("####0.00000");
+
+    public Model getModel() {
+        return model;
+    }
+
+    public DataModule getDataModule() {
+        return dataModule;
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/task/ConverterConfigurer.java b/src/test/java/com/jstarcraft/rns/task/ConverterConfigurer.java
new file mode 100644
index 0000000..967875a
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/task/ConverterConfigurer.java
@@ -0,0 +1,23 @@
+package com.jstarcraft.rns.task;
+
+public class ConverterConfigurer {
+
+    private String name;
+
+    private String type;
+
+    private String path;
+
+    public String getName() {
+        return name;
+    }
+
+    public String getType() {
+        return type;
+    }
+
+    public String getPath() {
+        return path;
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/task/ModuleConfigurer.java b/src/test/java/com/jstarcraft/rns/task/ModuleConfigurer.java
new file mode 100644
index 0000000..0c5c27b
--- /dev/null
+++ b/src/test/java/com/jstarcraft/rns/task/ModuleConfigurer.java
@@ -0,0 +1,19 @@
+package com.jstarcraft.rns.task;
+
+import java.util.TreeMap;
+
+public class ModuleConfigurer {
+
+    private String name;
+
+    private TreeMap configuration;
+
+    public String getName() {
+        return name;
+    }
+
+    public TreeMap getConfiguration() {
+        return configuration;
+    }
+
+}
diff --git a/src/test/java/com/jstarcraft/rns/task/RankingTask.java b/src/test/java/com/jstarcraft/rns/task/RankingTask.java
new file mode 100644
index 0000000..a9461de
--- /dev/null +++ b/src/test/java/com/jstarcraft/rns/task/RankingTask.java @@ -0,0 +1,103 @@ +package com.jstarcraft.rns.task; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.module.ArrayInstance; +import com.jstarcraft.ai.data.module.ReferenceModule; +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.ranking.AUCEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MAPEvaluator; +import com.jstarcraft.ai.evaluate.ranking.MRREvaluator; +import com.jstarcraft.ai.evaluate.ranking.NDCGEvaluator; +import com.jstarcraft.ai.evaluate.ranking.NoveltyEvaluator; +import com.jstarcraft.ai.evaluate.ranking.PrecisionEvaluator; +import com.jstarcraft.ai.evaluate.ranking.RecallEvaluator; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.Integer2FloatKeyValue; +import com.jstarcraft.rns.model.Model; + +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; + +/** + * 排序任务 + * + * @author Birdy + * + */ +public class RankingTask extends AbstractTask { + + public RankingTask(Model recommender, Option configuration) { + super(recommender, configuration); + } + + public RankingTask(Class clazz, Option configuration) { + super(clazz, configuration); + } + + @Override + protected Collection getEvaluators(SparseMatrix featureMatrix) { + Collection evaluators = new LinkedList<>(); + int size = configurator.getInteger("recommender.recommender.ranking.topn", 10); + evaluators.add(new AUCEvaluator(size)); + evaluators.add(new MAPEvaluator(size)); + evaluators.add(new MRREvaluator(size)); + evaluators.add(new NDCGEvaluator(size)); + evaluators.add(new NoveltyEvaluator(size, featureMatrix)); + evaluators.add(new PrecisionEvaluator(size)); + evaluators.add(new RecallEvaluator(size)); + return evaluators; + } + + @Override + protected IntSet check(int userIndex) { + ReferenceModule testModule = testModules[userIndex]; + IntSet itemSet = new IntOpenHashSet(); + for (DataInstance instance : testModule) { + itemSet.add(instance.getQualityFeature(itemDimension)); + } + return itemSet; + } + + @Override + protected IntList recommend(Model recommender, int userIndex) { + ReferenceModule trainModule = trainModules[userIndex]; + ReferenceModule testModule = testModules[userIndex]; + IntSet itemSet = new IntOpenHashSet(); + for (DataInstance instance : trainModule) { + itemSet.add(instance.getQualityFeature(itemDimension)); + } + // TODO 此处代码需要重构 + ArrayInstance copy = new ArrayInstance(trainMarker.getQualityOrder(), trainMarker.getQuantityOrder()); + copy.copyInstance(testModule.getInstance(0)); + copy.setQualityFeature(userDimension, userIndex); + + List rankList = new ArrayList<>(itemSize - itemSet.size()); + for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) { + if (itemSet.contains(itemIndex)) { + continue; + } + copy.setQualityFeature(itemDimension, itemIndex); + recommender.predict(copy); + rankList.add(new Integer2FloatKeyValue(itemIndex, copy.getQuantityMark())); + } + Collections.sort(rankList, (left, right) -> { + return Float.compare(right.getValue(), left.getValue()); + }); + + IntList recommendList = new IntArrayList(rankList.size()); + for (Integer2FloatKeyValue keyValue : 
rankList) { + recommendList.add(keyValue.getKey()); + } + return recommendList; + } + +} diff --git a/src/test/java/com/jstarcraft/rns/task/RatingTask.java b/src/test/java/com/jstarcraft/rns/task/RatingTask.java new file mode 100644 index 0000000..50c81dc --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/task/RatingTask.java @@ -0,0 +1,78 @@ +package com.jstarcraft.rns.task; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +import com.jstarcraft.ai.data.DataInstance; +import com.jstarcraft.ai.data.module.ArrayInstance; +import com.jstarcraft.ai.data.module.ReferenceModule; +import com.jstarcraft.ai.evaluate.Evaluator; +import com.jstarcraft.ai.evaluate.rating.MAEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MPEEvaluator; +import com.jstarcraft.ai.evaluate.rating.MSEEvaluator; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; +import com.jstarcraft.core.common.option.Option; +import com.jstarcraft.core.utility.Integer2FloatKeyValue; +import com.jstarcraft.rns.model.Model; + +import it.unimi.dsi.fastutil.floats.FloatArrayList; +import it.unimi.dsi.fastutil.floats.FloatList; + +/** + * 评分任务 + * + * @author Birdy + * + */ +public class RatingTask extends AbstractTask { + + public RatingTask(Model recommender, Option configuration) { + super(recommender, configuration); + } + + public RatingTask(Class clazz, Option configuration) { + super(clazz, configuration); + } + + @Override + protected Collection getEvaluators(SparseMatrix featureMatrix) { + Collection evaluators = new LinkedList<>(); + float minimum = configurator.getFloat("recommender.recommender.rating.minimum", 0.5F); + float maximum = configurator.getFloat("recommender.recommender.rating.maximum", 4F); + evaluators.add(new MAEEvaluator(minimum, maximum)); + evaluators.add(new MPEEvaluator(minimum, maximum, 0.01F)); + evaluators.add(new MSEEvaluator(minimum, maximum)); + return evaluators; + } + + @Override + protected FloatList check(int userIndex) { + ReferenceModule testModule = testModules[userIndex]; + FloatList scoreList = new FloatArrayList(testModule.getSize()); + for (DataInstance instance : testModule) { + scoreList.add(instance.getQuantityMark()); + } + return scoreList; + } + + @Override + protected FloatList recommend(Model recommender, int userIndex) { + ReferenceModule testModule = testModules[userIndex]; + ArrayInstance copy = new ArrayInstance(testMarker.getQualityOrder(), testMarker.getQuantityOrder()); + List rateList = new ArrayList<>(testModule.getSize()); + for (DataInstance instance : testModule) { + copy.copyInstance(instance); + recommender.predict(copy); + rateList.add(new Integer2FloatKeyValue(copy.getQualityFeature(itemDimension), copy.getQuantityMark())); + } + + FloatList recommendList = new FloatArrayList(rateList.size()); + for (Integer2FloatKeyValue keyValue : rateList) { + recommendList.add(keyValue.getValue()); + } + return recommendList; + } + +} diff --git a/src/test/java/com/jstarcraft/rns/task/SeparatorConfigurer.java b/src/test/java/com/jstarcraft/rns/task/SeparatorConfigurer.java new file mode 100644 index 0000000..1546b0d --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/task/SeparatorConfigurer.java @@ -0,0 +1,29 @@ +package com.jstarcraft.rns.task; + +public class SeparatorConfigurer { + + private String name; + + private String type; + + private String matchField; + + private String sortField; + + public String getName() { + return name; + } + + public String getType() { + return type; + } + + 
public String getMatchField() { + return matchField; + } + + public String getSortField() { + return sortField; + } + +} diff --git a/src/test/java/com/jstarcraft/rns/utility/GammaUtilityTestCase.java b/src/test/java/com/jstarcraft/rns/utility/GammaUtilityTestCase.java new file mode 100644 index 0000000..70590b7 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/utility/GammaUtilityTestCase.java @@ -0,0 +1,111 @@ +package com.jstarcraft.rns.utility; + +import org.apache.commons.math3.special.Gamma; +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.math.MathUtility; + +import net.sourceforge.jdistlib.math.PolyGamma; + +public class GammaUtilityTestCase { + + @Test + public void testGamma() { + // logGamma遇到负数会变为NaN或者无穷. + Assert.assertTrue(Float.isNaN(GammaUtility.logGamma(0F)) == Double.isNaN(Gamma.logGamma(0D))); + Assert.assertTrue(Float.isNaN(GammaUtility.logGamma(-0F)) == Double.isNaN(Gamma.logGamma(-0D))); + Assert.assertTrue(Float.isNaN(GammaUtility.logGamma(-0.5F)) == Double.isNaN(Gamma.logGamma(-0.5D))); + Assert.assertTrue(Float.isNaN(GammaUtility.logGamma(-1F)) == Double.isNaN(Gamma.logGamma(-1D))); + + Assert.assertTrue(MathUtility.equal(GammaUtility.logGamma(0.1F), (float) Gamma.logGamma(0.1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.logGamma(0.5f), (float) Gamma.logGamma(0.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.logGamma(1F), (float) Gamma.logGamma(1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.logGamma(1.5F), (float) Gamma.logGamma(1.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.logGamma(8F), (float) Gamma.logGamma(8D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.logGamma(8.1F), (float) Gamma.logGamma(8.1D))); + + Assert.assertTrue(MathUtility.equal(GammaUtility.gamma(0.1F), (float) Gamma.gamma(0.1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.gamma(0.5F), (float) Gamma.gamma(0.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.gamma(1F), (float) Gamma.gamma(1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.gamma(1.5F), (float) Gamma.gamma(1.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.gamma(8F), (float) Gamma.gamma(8D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.gamma(8.1F), (float) Gamma.gamma(8.1D))); + } + + @Test + // TODO 通过https://www.wolframalpha.com/input验证,Apache比较准确. + public void testDigamma() { + // digamma遇到负整数会变为NaN或者无穷. 
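+        // Round trip: inverse(digamma(x), n) should recover x (the second argument is assumed here to be the refinement iteration count).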
+ Assert.assertTrue(Float.isNaN(GammaUtility.digamma(0F)) == Double.isNaN(Gamma.digamma(0D))); + Assert.assertTrue(Float.isNaN(GammaUtility.digamma(-0F)) == Double.isNaN(Gamma.digamma(-0D))); + Assert.assertTrue(Float.isNaN(GammaUtility.digamma(-1F)) == Double.isNaN(Gamma.digamma(-1D))); + + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(0.1F), (float) Gamma.digamma(0.1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.inverse(GammaUtility.digamma(0.1F), 5), 0.1F)); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(0.5F), (float) Gamma.digamma(0.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.inverse(GammaUtility.digamma(0.5F), 5), 0.5F)); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(1F), (float) Gamma.digamma(1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.inverse(GammaUtility.digamma(1F), 5), 1F)); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(1.5F), (float) Gamma.digamma(1.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.inverse(GammaUtility.digamma(1.5F), 5), 1.5F)); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(8F), (float) Gamma.digamma(8D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.inverse(GammaUtility.digamma(8F), 5), 8F)); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(8.1F), (float) Gamma.digamma(8.1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.inverse(GammaUtility.digamma(8.1F), 5), 8.1F)); + + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(-0.1F), (float) Gamma.digamma(-0.1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(-0.5F), (float) Gamma.digamma(-0.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(-1.5F), (float) Gamma.digamma(-1.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.digamma(-8.1F), (float) Gamma.digamma(-8.1D))); + } + + @Test + // TODO 通过https://www.wolframalpha.com/input验证,Apache比较准确. + public void testTrigamma() { + // trigamma遇到负整数会变为NaN或者无穷. + Assert.assertTrue(Float.isNaN(GammaUtility.trigamma(0F)) == Double.isNaN(Gamma.trigamma(0D))); + Assert.assertTrue(Float.isNaN(GammaUtility.trigamma(-0F)) == Double.isNaN(Gamma.trigamma(-0D))); + Assert.assertTrue(Float.isNaN(GammaUtility.trigamma(-1F)) == Double.isNaN(Gamma.trigamma(-1D))); + + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(0.1F), (float) Gamma.trigamma(0.1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(0.5F), (float) Gamma.trigamma(0.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(1F), (float) Gamma.trigamma(1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(1.5F), (float) Gamma.trigamma(1.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(8F), (float) Gamma.trigamma(8D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(8.1F), (float) Gamma.trigamma(8.1D))); + + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(-0.1F), (float) Gamma.trigamma(-0.1D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(-0.5F), (float) Gamma.trigamma(-0.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(-1.5F), (float) Gamma.trigamma(-1.5D))); + Assert.assertTrue(MathUtility.equal(GammaUtility.trigamma(-8.1F), (float) Gamma.trigamma(-8.1D))); + } + + public void test() { + // log遇到正负零会变为无穷. 
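+        // (No @Test annotation, so JUnit will not execute this scratch method; it only prints values for manual inspection.)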
+ System.out.println(Math.exp(0)); + System.out.println(Math.log(Math.exp(0))); + System.out.println(Math.log(Math.exp(1))); + System.out.println(Math.log(Math.exp(2))); + System.out.println(Math.log(Math.exp(3))); + System.out.println(Math.log(Math.exp(-3))); + + // psiGamma遇到负整数会变为NaN或者无穷. + // logGamma遇到负数会变为NaN或者无穷. + System.out.println(GammaUtility.logGamma(1)); + System.out.println(GammaUtility.logGamma(2)); + System.out.println(GammaUtility.logGamma(3)); + System.out.println(GammaUtility.logGamma(-1)); + System.out.println(GammaUtility.logGamma(-2)); + System.out.println(GammaUtility.logGamma(-3)); + + // TODO 准备将PolyGamma与Gamma整合到GammaUtility. + System.out.println(PolyGamma.psigamma(-1.1D, 0)); + System.out.println(PolyGamma.psigamma(-1D, 0)); + System.out.println(PolyGamma.psigamma(-1.1D, 1)); + System.out.println(PolyGamma.psigamma(-1D, 1)); + System.out.println(PolyGamma.psigamma(-1.1D, 2)); + System.out.println(PolyGamma.psigamma(-1D, 2)); + } + +} diff --git a/src/test/java/com/jstarcraft/rns/utility/SampleUtilityTestCase.java b/src/test/java/com/jstarcraft/rns/utility/SampleUtilityTestCase.java new file mode 100644 index 0000000..7763ec2 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/utility/SampleUtilityTestCase.java @@ -0,0 +1,41 @@ +package com.jstarcraft.rns.utility; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.math.structure.vector.ArrayVector; + +public class SampleUtilityTestCase { + + @Test + public void testBinarySearch() { + int[] indexes = new int[] { 5, 10, 15 }; + float[] values = new float[] { 5F, 10F, 15F }; + + for (float index = 0F; index < 5F; index += 0.5F) { + Assert.assertEquals(0, SampleUtility.binarySearch(values, 0, values.length - 1, index)); + } + + for (float index = 5F; index < 10F; index += 0.5F) { + Assert.assertEquals(1, SampleUtility.binarySearch(values, 0, values.length - 1, index)); + } + + for (float index = 10; index < 15F; index += 0.5F) { + Assert.assertEquals(2, SampleUtility.binarySearch(values, 0, values.length - 1, index)); + } + + ArrayVector vector = new ArrayVector(3, indexes, values); + for (float index = 0F; index < 5F; index += 0.5F) { + Assert.assertEquals(0, SampleUtility.binarySearch(vector, 0, values.length - 1, index)); + } + + for (float index = 5F; index < 10F; index += 0.5F) { + Assert.assertEquals(1, SampleUtility.binarySearch(vector, 0, values.length - 1, index)); + } + + for (float index = 10; index < 15F; index += 0.5F) { + Assert.assertEquals(2, SampleUtility.binarySearch(vector, 0, values.length - 1, index)); + } + } + +} diff --git a/src/test/java/com/jstarcraft/rns/utility/SearchUtilityTestCase.java b/src/test/java/com/jstarcraft/rns/utility/SearchUtilityTestCase.java new file mode 100644 index 0000000..457e1a3 --- /dev/null +++ b/src/test/java/com/jstarcraft/rns/utility/SearchUtilityTestCase.java @@ -0,0 +1,79 @@ +package com.jstarcraft.rns.utility; + +import java.util.concurrent.Future; + +import org.junit.Assert; +import org.junit.Test; + +import com.jstarcraft.ai.environment.EnvironmentContext; +import com.jstarcraft.ai.environment.EnvironmentFactory; +import com.jstarcraft.ai.math.structure.MathCalculator; +import com.jstarcraft.ai.math.structure.matrix.DenseMatrix; +import com.jstarcraft.ai.math.structure.matrix.HashMatrix; +import com.jstarcraft.ai.math.structure.matrix.MatrixScalar; +import com.jstarcraft.ai.math.structure.matrix.SparseMatrix; + +import it.unimi.dsi.fastutil.floats.Float2IntAVLTreeMap; +import it.unimi.dsi.fastutil.floats.Float2IntSortedMap; 
+import it.unimi.dsi.fastutil.longs.Long2FloatRBTreeMap;
+
+public class SearchUtilityTestCase {
+
+    @Test
+    public void testPageRank() throws Exception {
+        testPageRank(MathCalculator.SERIAL);
+        testPageRank(MathCalculator.PARALLEL);
+    }
+
+    private void testPageRank(MathCalculator mode) throws Exception {
+        EnvironmentContext context = EnvironmentFactory.getContext();
+        Future task = context.doTask(() -> {
+            int dimension = 7;
+            HashMatrix table = new HashMatrix(true, dimension, dimension, new Long2FloatRBTreeMap());
+            table.setValue(0, 1, 0.5F);
+            table.setValue(0, 2, 0.5F);
+
+            table.setValue(2, 0, 0.3F);
+            table.setValue(2, 1, 0.3F);
+            table.setValue(2, 4, 0.3F);
+
+            table.setValue(3, 4, 0.5F);
+            table.setValue(3, 5, 0.5F);
+
+            table.setValue(4, 3, 0.5F);
+            table.setValue(4, 5, 0.5F);
+
+            table.setValue(5, 3, 1F);
+
+            table.setValue(6, 1, 0.5F);
+            table.setValue(6, 3, 0.5F);
+            SparseMatrix sparseMatrix = SparseMatrix.valueOf(dimension, dimension, table);
+            DenseMatrix denseMatrix = DenseMatrix.valueOf(dimension, dimension);
+            for (MatrixScalar scalar : sparseMatrix) {
+                denseMatrix.setValue(scalar.getRow(), scalar.getColumn(), scalar.getValue());
+            }
+
+            // Map each PageRank score to its node index; iterating the sorted map yields the ranking.
+            Float2IntSortedMap sparseSort = new Float2IntAVLTreeMap();
+            {
+                int index = 0;
+                for (float score : SearchUtility.pageRank(mode, dimension, sparseMatrix)) {
+                    sparseSort.put(score, index++);
+                }
+            }
+            Assert.assertArrayEquals(new int[] { 6, 0, 2, 1, 4, 5, 3 }, sparseSort.values().toIntArray());
+
+            Float2IntSortedMap denseSort = new Float2IntAVLTreeMap();
+            {
+                int index = 0;
+                for (float score : SearchUtility.pageRank(mode, dimension, denseMatrix)) {
+                    denseSort.put(score, index++);
+                }
+            }
+            Assert.assertArrayEquals(new int[] { 6, 0, 2, 1, 4, 5, 3 }, denseSort.values().toIntArray());
+
+            Assert.assertEquals(sparseSort, denseSort);
+        });
+        task.get();
+    }
+
+}
diff --git a/src/test/resources/com/jstarcraft/rns/script/Model.bsh b/src/test/resources/com/jstarcraft/rns/script/Model.bsh
new file mode 100644
index 0000000..5afb5fb
--- /dev/null
+++ b/src/test/resources/com/jstarcraft/rns/script/Model.bsh
@@ -0,0 +1,24 @@
+// Build the configuration
+keyValues = new Properties();
+keyValues.load(loader.getResourceAsStream("data.properties"));
+keyValues.load(loader.getResourceAsStream("model/benchmark/randomguess-test.properties"));
+option = new Option(keyValues);
+
+// This object is returned to the Java program
+_data = new HashMap();
+
+// Build the ranking task
+task = new RankingTask(RandomGuessModel.class, option);
+// Train and evaluate the model, then collect the ranking measures
+measures = task.execute();
+_data.put("precision", measures.get(PrecisionEvaluator.class));
+_data.put("recall", measures.get(RecallEvaluator.class));
+
+// Build the rating task
+task = new RatingTask(RandomGuessModel.class, option);
+// Train and evaluate the model, then collect the rating measures
+measures = task.execute();
+_data.put("mae", measures.get(MAEEvaluator.class));
+_data.put("mse", measures.get(MSEEvaluator.class));
+
+_data;
\ No newline at end of file
diff --git a/src/test/resources/com/jstarcraft/rns/script/Model.groovy b/src/test/resources/com/jstarcraft/rns/script/Model.groovy
new file mode 100644
index 0000000..84e5727
--- /dev/null
+++ b/src/test/resources/com/jstarcraft/rns/script/Model.groovy
@@ -0,0 +1,24 @@
+// Build the configuration
+def keyValues = new Properties();
+keyValues.load(loader.getResourceAsStream("data.properties"));
+keyValues.load(loader.getResourceAsStream("model/benchmark/randomguess-test.properties"));
+def option = new Option(keyValues);
+
+// This object is returned to the Java program
+def _data = [:];
+
+// Build the ranking task
+task = new RankingTask(RandomGuessModel.class, option);
+// Train and evaluate the model, then collect the ranking measures
+measures = task.execute();
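+// 'measures' maps evaluator classes to float scores (shape inferred from the get(...) calls below).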
+_data.precision = measures.get(PrecisionEvaluator.class);
+_data.recall = measures.get(RecallEvaluator.class);
+
+// Build the rating task
+task = new RatingTask(RandomGuessModel.class, option);
+// Train and evaluate the model, then collect the rating measures
+measures = task.execute();
+_data.mae = measures.get(MAEEvaluator.class);
+_data.mse = measures.get(MSEEvaluator.class);
+
+_data;
\ No newline at end of file
diff --git a/src/test/resources/com/jstarcraft/rns/script/Model.js b/src/test/resources/com/jstarcraft/rns/script/Model.js
new file mode 100644
index 0000000..adb7437
--- /dev/null
+++ b/src/test/resources/com/jstarcraft/rns/script/Model.js
@@ -0,0 +1,24 @@
+// Build the configuration
+var keyValues = new Properties();
+keyValues.load(loader.getResourceAsStream("data.properties"));
+keyValues.load(loader.getResourceAsStream("model/benchmark/randomguess-test.properties"));
+var option = new Option(keyValues);
+
+// This object is returned to the Java program
+var _data = {};
+
+// Build the ranking task
+task = new RankingTask(RandomGuessModel.class, option);
+// Train and evaluate the model, then collect the ranking measures
+measures = task.execute();
+_data['precision'] = measures.get(PrecisionEvaluator.class);
+_data['recall'] = measures.get(RecallEvaluator.class);
+
+// Build the rating task
+task = new RatingTask(RandomGuessModel.class, option);
+// Train and evaluate the model, then collect the rating measures
+measures = task.execute();
+_data['mae'] = measures.get(MAEEvaluator.class);
+_data['mse'] = measures.get(MSEEvaluator.class);
+
+_data;
\ No newline at end of file
diff --git a/src/test/resources/com/jstarcraft/rns/script/Model.kt b/src/test/resources/com/jstarcraft/rns/script/Model.kt
new file mode 100644
index 0000000..a7a8bf8
--- /dev/null
+++ b/src/test/resources/com/jstarcraft/rns/script/Model.kt
@@ -0,0 +1,25 @@
+// Build the configuration
+var keyValues = Properties();
+var loader = bindings["loader"] as ClassLoader;
+keyValues.load(loader.getResourceAsStream("data.properties"));
+keyValues.load(loader.getResourceAsStream("model/benchmark/randomguess-test.properties"));
+var option = Option(keyValues);
+
+// This object is returned to the Java program (explicit type arguments added so the script compiles)
+var _data = mutableMapOf<String, Float>();
+
+// Build the ranking task
+var rankingTask = RankingTask(RandomGuessModel::class.java, option);
+// Train and evaluate the model, then collect the ranking measures
+val rankingMeasures = rankingTask.execute();
+_data["precision"] = rankingMeasures.getFloat(PrecisionEvaluator::class.java);
+_data["recall"] = rankingMeasures.getFloat(RecallEvaluator::class.java);
+
+// Build the rating task
+var ratingTask = RatingTask(RandomGuessModel::class.java, option);
+// Train and evaluate the model, then collect the rating measures
+var ratingMeasures = ratingTask.execute();
+_data["mae"] = ratingMeasures.getFloat(MAEEvaluator::class.java);
+_data["mse"] = ratingMeasures.getFloat(MSEEvaluator::class.java);
+
+_data;
\ No newline at end of file
diff --git a/src/test/resources/com/jstarcraft/rns/script/Model.lua b/src/test/resources/com/jstarcraft/rns/script/Model.lua
new file mode 100644
index 0000000..8a7c0de
--- /dev/null
+++ b/src/test/resources/com/jstarcraft/rns/script/Model.lua
@@ -0,0 +1,24 @@
+-- Build the configuration
+local keyValues = Properties.new();
+keyValues:load(loader:getResourceAsStream("data.properties"));
+keyValues:load(loader:getResourceAsStream("model/benchmark/randomguess-test.properties"));
+local option = Option.new(keyValues);
+
+-- This object is returned to the Java program
+local _data = {};
+
+-- Build the ranking task
+task = RankingTask.new(RandomGuessModel, option);
+-- Train and evaluate the model, then collect the ranking measures
+measures = task:execute();
+_data["precision"] = measures:get(PrecisionEvaluator);
+_data["recall"] = measures:get(RecallEvaluator);
+
+-- Build the rating task
+task = RatingTask.new(RandomGuessModel, option);
+-- Train and evaluate the model, then collect the rating measures
+measures = task:execute();
+_data["mae"] = measures:get(MAEEvaluator);
+_data["mse"] = measures:get(MSEEvaluator);
+
+return _data;
\ No newline at end of file
diff --git a/src/test/resources/com/jstarcraft/rns/script/Model.py b/src/test/resources/com/jstarcraft/rns/script/Model.py
new file mode 100644
index 0000000..ed21da9
--- /dev/null
+++ b/src/test/resources/com/jstarcraft/rns/script/Model.py
@@ -0,0 +1,22 @@
+# Build the configuration
+keyValues = Properties()
+keyValues.load(loader.getResourceAsStream("data.properties"))
+keyValues.load(loader.getResourceAsStream("model/benchmark/randomguess-test.properties"))
+option = Option(keyValues)
+
+# This object is returned to the Java program
+_data = {}
+
+# Build the ranking task
+task = RankingTask(RandomGuessModel, option)
+# Train and evaluate the model, then collect the ranking measures
+measures = task.execute()
+_data['precision'] = measures.get(PrecisionEvaluator)
+_data['recall'] = measures.get(RecallEvaluator)
+
+# Build the rating task
+task = RatingTask(RandomGuessModel, option)
+# Train and evaluate the model, then collect the rating measures
+measures = task.execute()
+_data['mae'] = measures.get(MAEEvaluator)
+_data['mse'] = measures.get(MSEEvaluator)
\ No newline at end of file
diff --git a/src/test/resources/com/jstarcraft/rns/script/Model.rb b/src/test/resources/com/jstarcraft/rns/script/Model.rb
new file mode 100644
index 0000000..23bfb2f
--- /dev/null
+++ b/src/test/resources/com/jstarcraft/rns/script/Model.rb
@@ -0,0 +1,24 @@
+# Build the configuration
+keyValues = Properties.new()
+keyValues.load($loader.getResourceAsStream("data.properties"))
+keyValues.load($loader.getResourceAsStream("model/benchmark/randomguess-test.properties"))
+option = Option.new(keyValues)
+
+# This object is returned to the Java program
+_data = Hash.new()
+
+# Build the ranking task
+task = RankingTask.new(RandomGuessModel.java_class, option)
+# Train and evaluate the model, then collect the ranking measures
+measures = task.execute()
+_data['precision'] = measures.get(PrecisionEvaluator.java_class)
+_data['recall'] = measures.get(RecallEvaluator.java_class)
+
+# Build the rating task
+task = RatingTask.new(RandomGuessModel.java_class, option)
+# Train and evaluate the model, then collect the rating measures
+measures = task.execute()
+_data['mae'] = measures.get(MAEEvaluator.java_class)
+_data['mse'] = measures.get(MSEEvaluator.java_class)
+
+_data;
\ No newline at end of file
diff --git a/src/test/resources/data.properties b/src/test/resources/data.properties
new file mode 100644
index 0000000..b680504
--- /dev/null
+++ b/src/test/resources/data.properties
@@ -0,0 +1,14 @@
+#TODO Separate the environment configuration from the algorithm configuration
+data.attributes.dicrete={"user":"int","item":"int","instant":"int"}
+data.attributes.continuous={"score":"float"}
+
+data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"instant","4":"score"}}]
+data.converters=[{"name":"score","type":"csv","path":"data/filmtrust/score.txt"}]
+data.separator={"name":"score","type":"ratio","matchField":null,"sortField":null}
+
+data.model.fields.user=user
+data.model.fields.item=item
+data.model.fields.instant=instant
+data.model.fields.score=score
+
+recommender.random.seed=0
\ No newline at end of file
diff --git a/src/test/resources/data/Foursquare.properties b/src/test/resources/data/Foursquare.properties
new file mode 100644
index 0000000..e90fc3f
--- /dev/null
+++ b/src/test/resources/data/Foursquare.properties
@@ -0,0 +1,15 @@
+data.attributes.dicrete={"user":"java.lang.String","item":"java.lang.String"}
+data.attributes.continuous={"score":"float","longitude":"float","latitude":"float"}
+
+data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"score"}},{"name":"location","configuration":{"1":"item","2":"longitude","3":"latitude"}}]
+data.converters=[{"name":"score","type":"csv","path":"data/poi/FourSquare/checkin/FourSquare.txt"},{"name":"location","type":"csv","path":"data/poi/FourSquare/Location.txt"}]
+data.separator={"name":"score","type":"ratio","matchField":null,"sortField":null} +data.separator.threshold=69159 + +data.model.fields.user=user +data.model.fields.item=item +data.model.fields.score=score +data.model.fields.longitude=longitude +data.model.fields.latitude=latitude + +recommender.random.seed=0 diff --git a/src/test/resources/data/dc_dense.properties b/src/test/resources/data/dc_dense.properties new file mode 100644 index 0000000..3a52e67 --- /dev/null +++ b/src/test/resources/data/dc_dense.properties @@ -0,0 +1,14 @@ +data.attributes.dicrete={"user":"int","item":"int","comment":"java.lang.String"} +data.attributes.continuous={"score":"float"} + +data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"comment","4":"score"}}] +data.converters=[{"name":"score","type":"csv","path":"data/dc_dense.txt"}] +data.separator={"name":"score","type":"ratio","matchField":null,"sortField":null} +data.separator.delimiter=, + +data.model.fields.user=user +data.model.fields.item=item +data.model.fields.score=score +data.model.fields.comment=comment + +recommender.random.seed=0 diff --git a/src/test/resources/data/filmtrust.properties b/src/test/resources/data/filmtrust.properties new file mode 100644 index 0000000..0da3d4d --- /dev/null +++ b/src/test/resources/data/filmtrust.properties @@ -0,0 +1,14 @@ +data.attributes.dicrete={"user":"int","item":"int","instant":"int"} +data.attributes.continuous={"score":"float","coefficient":"float"} + +data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"instant","4":"score"}},{"name":"social","configuration":{"2":"user","3":"coefficient"}}] +data.converters=[{"name":"score","type":"csv","path":"data/filmtrust/score.txt"},{"name":"social","type":"csv","path":"data/filmtrust/trust.txt"}] +data.separator={"name":"score","type":"ratio","matchField":"user","sortField":"instant"} + +data.model.fields.user=user +data.model.fields.item=item +data.model.fields.instant=instant +data.model.fields.score=score +data.model.fields.coefficient=coefficient + +recommender.random.seed=0 diff --git a/src/test/resources/data/game.properties b/src/test/resources/data/game.properties new file mode 100644 index 0000000..a98510c --- /dev/null +++ b/src/test/resources/data/game.properties @@ -0,0 +1,14 @@ +data.attributes.dicrete={"user":"long","item":"int","instant":"long","level":"int"} +data.attributes.continuous={"score":"float"} + +data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"instant","4":"level","5":"score"}}] +data.converters=[{"name":"score","type":"csv","path":"data/game/score.txt"}] +data.separator={"name":"score","type":"ratio","matchField":"user","sortField":"instant"} + +data.model.fields.user=user +data.model.fields.item=item +data.model.fields.instant=instant +data.model.fields.score=score +data.model.fields.context=level + +recommender.random.seed=0 diff --git a/src/test/resources/data/ml100k.properties b/src/test/resources/data/ml100k.properties new file mode 100644 index 0000000..0a845f9 --- /dev/null +++ b/src/test/resources/data/ml100k.properties @@ -0,0 +1,14 @@ +data.attributes.dicrete={"user":"int","item":"int","instant":"int"} +data.attributes.continuous={"score":"float","coefficient":"float"} + +data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"instant","4":"score"}}] +data.converters=[{"name":"score","type":"csv","path":"data/movielens/ml-100k/score.txt"}] +data.separator={"name":"score","type":"ratio","matchField":"user","sortField":"instant"} + 
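+# With type=ratio plus matchField=user and sortField=instant, the data.separator above appears
+# to split each user's records by ratio in timestamp order (inference from the field names;
+# the separator semantics are not documented in this diff).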
+data.model.fields.user=user
+data.model.fields.item=item
+data.model.fields.instant=instant
+data.model.fields.score=score
+data.model.fields.coefficient=coefficient
+
+recommender.random.seed=0
diff --git a/src/test/resources/data/musical_instruments.properties b/src/test/resources/data/musical_instruments.properties
new file mode 100644
index 0000000..7c2e859
--- /dev/null
+++ b/src/test/resources/data/musical_instruments.properties
@@ -0,0 +1,14 @@
+data.attributes.dicrete={"user":"int","item":"int","comment":"java.lang.String"}
+data.attributes.continuous={"score":"float"}
+
+data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"comment","4":"score"}}]
+data.converters=[{"name":"score","type":"csv","path":"data/musical_instruments.txt"}]
+data.separator={"name":"score","type":"ratio","matchField":null,"sortField":null}
+data.separator.delimiter=,
+
+data.model.fields.user=user
+data.model.fields.item=item
+data.model.fields.score=score
+data.model.fields.comment=comment
+
+recommender.random.seed=0
diff --git a/src/test/resources/data/product.properties b/src/test/resources/data/product.properties
new file mode 100644
index 0000000..c8fbc50
--- /dev/null
+++ b/src/test/resources/data/product.properties
@@ -0,0 +1,19 @@
+data.attributes.dicrete={"user":"java.lang.String","item":"java.lang.String","feature":"int"}
+data.attributes.continuous={"score":"float","degree":"float","coefficient":"float"}
+
+data.modules=[{"name":"score","configuration":{"1":"user","2":"item","3":"score"}},{"name":"article","configuration":{"1":"item","2":"feature","3":"degree"}},{"name":"relation","configuration":{"2":"item","3":"coefficient"}}]
+data.converters=[{"name":"score","type":"csv","path":"data/product/scores.txt"},{"name":"article","type":"csv","path":"data/product/features.txt"},{"name":"relation","type":"csv","path":"data/product/relation.txt"}]
+data.separator={"name":"score","type":"ratio","matchField":null,"sortField":null}
+data.separator.delimiter=,
+
+data.model.fields.user=user
+data.model.fields.item=item
+data.model.fields.score=score
+data.model.fields.article=item
+data.model.fields.feature=feature
+data.model.fields.degree=degree
+data.model.fields.left=item
+data.model.fields.right=item
+data.model.fields.coefficient=coefficient
+
+recommender.random.seed=0
diff --git a/src/test/resources/log4j2.xml b/src/test/resources/log4j2.xml
new file mode 100644
index 0000000..d6ae0c0
--- /dev/null
+++ b/src/test/resources/log4j2.xml
@@ -0,0 +1,77 @@
+[log4j2.xml: 77 lines of XML configuration; the markup was stripped during text extraction and is not recoverable]
\ No newline at end of file
diff --git a/src/test/resources/model/benchmark/constantguess-test.properties b/src/test/resources/model/benchmark/constantguess-test.properties
new file mode 100644
index 0000000..e69de29
diff --git a/src/test/resources/model/benchmark/globalaverage-test.properties b/src/test/resources/model/benchmark/globalaverage-test.properties
new file mode 100644
index 0000000..e69de29
diff --git a/src/test/resources/model/benchmark/itemaverage-test.properties b/src/test/resources/model/benchmark/itemaverage-test.properties
new file mode 100644
index 0000000..e69de29
diff --git a/src/test/resources/model/benchmark/itemcluster-test.properties b/src/test/resources/model/benchmark/itemcluster-test.properties
new file mode 100644
index 0000000..7c02c43
--- /dev/null
+++ b/src/test/resources/model/benchmark/itemcluster-test.properties
@@ -0,0 +1,2 @@
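+# ItemCluster benchmark: the two keys below appear to set the number of clusters and the
+# iteration cap (descriptive note; interpretation inferred from the key names).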
+recommender.topic.number=10 +recommender.iterator.maximum=20 \ No newline at end of file diff --git a/src/test/resources/model/benchmark/mostpopular-test.properties b/src/test/resources/model/benchmark/mostpopular-test.properties new file mode 100644 index 0000000..26461c8 --- /dev/null +++ b/src/test/resources/model/benchmark/mostpopular-test.properties @@ -0,0 +1 @@ +recommender.recommender.isranking=true \ No newline at end of file diff --git a/src/test/resources/model/benchmark/randomguess-test.properties b/src/test/resources/model/benchmark/randomguess-test.properties new file mode 100644 index 0000000..e69de29 diff --git a/src/test/resources/model/benchmark/useraverage-test.properties b/src/test/resources/model/benchmark/useraverage-test.properties new file mode 100644 index 0000000..e69de29 diff --git a/src/test/resources/model/benchmark/usercluster-test.properties b/src/test/resources/model/benchmark/usercluster-test.properties new file mode 100644 index 0000000..7c02c43 --- /dev/null +++ b/src/test/resources/model/benchmark/usercluster-test.properties @@ -0,0 +1,2 @@ +recommender.topic.number=10 +recommender.iterator.maximum=20 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/bhfreeranking-test.properties b/src/test/resources/model/collaborative/bhfreeranking-test.properties new file mode 100644 index 0000000..a0d467a --- /dev/null +++ b/src/test/resources/model/collaborative/bhfreeranking-test.properties @@ -0,0 +1,11 @@ +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 +recommender.iterator.maximum=100 + +recommender.bhfree.alpha=0.01 +recommender.bhfree.beta=0.01 +recommender.bhfree.gamma=0.01 +recommender.bhfree.sigma=0.01 + +recommender.bhfree.user.topic.number=10 +recommender.bhfree.item.topic.number=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/bhfreerating-test.properties b/src/test/resources/model/collaborative/bhfreerating-test.properties new file mode 100644 index 0000000..a0d467a --- /dev/null +++ b/src/test/resources/model/collaborative/bhfreerating-test.properties @@ -0,0 +1,11 @@ +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 +recommender.iterator.maximum=100 + +recommender.bhfree.alpha=0.01 +recommender.bhfree.beta=0.01 +recommender.bhfree.gamma=0.01 +recommender.bhfree.sigma=0.01 + +recommender.bhfree.user.topic.number=10 +recommender.bhfree.item.topic.number=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/bucmranking-test.properties b/src/test/resources/model/collaborative/bucmranking-test.properties new file mode 100644 index 0000000..401a29e --- /dev/null +++ b/src/test/resources/model/collaborative/bucmranking-test.properties @@ -0,0 +1,8 @@ +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 + +recommender.iterator.maximum=100 +recommender.pgm.topic.number=10 +recommender.bucm.alpha=0.01 +recommender.bucm.beta=0.01 +recommender.bucm.gamma=0.01 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/bucmrating-test.properties b/src/test/resources/model/collaborative/bucmrating-test.properties new file mode 100644 index 0000000..401a29e --- /dev/null +++ b/src/test/resources/model/collaborative/bucmrating-test.properties @@ -0,0 +1,8 @@ +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 + +recommender.iterator.maximum=100 +recommender.pgm.topic.number=10 +recommender.bucm.alpha=0.01 +recommender.bucm.beta=0.01 +recommender.bucm.gamma=0.01 \ No newline at end of file diff --git 
a/src/test/resources/model/collaborative/itemknnranking-test.properties b/src/test/resources/model/collaborative/itemknnranking-test.properties new file mode 100644 index 0000000..227c3ae --- /dev/null +++ b/src/test/resources/model/collaborative/itemknnranking-test.properties @@ -0,0 +1,2 @@ +recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.JaccardIndexSimilarity +recommender.neighbors.knn.number=50 diff --git a/src/test/resources/model/collaborative/itemknnrating-test.properties b/src/test/resources/model/collaborative/itemknnrating-test.properties new file mode 100644 index 0000000..6ab5d01 --- /dev/null +++ b/src/test/resources/model/collaborative/itemknnrating-test.properties @@ -0,0 +1,2 @@ +recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.PCCSimilarity +recommender.neighbors.knn.number=50 diff --git a/src/test/resources/model/collaborative/ranking/aobpr-test.properties b/src/test/resources/model/collaborative/ranking/aobpr-test.properties new file mode 100644 index 0000000..0016f79 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/aobpr-test.properties @@ -0,0 +1,9 @@ +recommender.item.distribution.parameter=0.5 +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=100 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/ranking/aspectmodelranking-test.properties b/src/test/resources/model/collaborative/ranking/aspectmodelranking-test.properties new file mode 100644 index 0000000..f68ead7 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/aspectmodelranking-test.properties @@ -0,0 +1,4 @@ +recommender.iterator.maximum=20 +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 +recommender.topic.number=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/bpr-test.properties b/src/test/resources/model/collaborative/ranking/bpr-test.properties new file mode 100644 index 0000000..be0bdaf --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/bpr-test.properties @@ -0,0 +1,8 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=100 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 +recommender.learnRate.bolddriver=false +recommender.learnRate.decay=1.0 diff --git a/src/test/resources/model/collaborative/ranking/cdae-test.properties b/src/test/resources/model/collaborative/ranking/cdae-test.properties new file mode 100644 index 0000000..bfd4668 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/cdae-test.properties @@ -0,0 +1,11 @@ +recommender.iterator.learnrate=0.1 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.momentum=0.9 +recommender.iterator.maximum=100 +recommender.weight.regularization=0.01 +recommender.hidden.dimension=200 +recommender.hidden.activation=sigmoid +recommender.output.activation=identity +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.binarize.threshold=3 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/climf-test.properties b/src/test/resources/model/collaborative/ranking/climf-test.properties new file mode 100644 index 0000000..161bd77 --- /dev/null +++ 
b/src/test/resources/model/collaborative/ranking/climf-test.properties @@ -0,0 +1,11 @@ +recommender.iterator.learnrate=0.001 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=100 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 + +recommender.init.mean=0.0 +recommender.init.std=0.01 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/deepcross-test.properties b/src/test/resources/model/collaborative/ranking/deepcross-test.properties new file mode 100644 index 0000000..bd5b18d --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/deepcross-test.properties @@ -0,0 +1,9 @@ +recommender.iterator.learnrate=0.05 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.momentum=0.9 +recommender.iterator.maximum=10 +recommender.weight.regularization=0.01 + +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.binarize.threshold=3 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/deepfm-test.properties b/src/test/resources/model/collaborative/ranking/deepfm-test.properties new file mode 100644 index 0000000..e4119ef --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/deepfm-test.properties @@ -0,0 +1,9 @@ +recommender.iterator.learnrate=0.05 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.momentum=0.9 +recommender.iterator.maximum=20 +recommender.weight.regularization=0.01 + +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.binarize.threshold=3 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/eals-test.properties b/src/test/resources/model/collaborative/ranking/eals-test.properties new file mode 100644 index 0000000..a91960d --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/eals-test.properties @@ -0,0 +1,17 @@ +recommender.iterator.maximum=20 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=32 + +#0:eALS MF; 1:WRMF; 2: both +recommender.eals.wrmf.judge=1 + +#the overall weight of missing data c0 +recommender.eals.overall=128 + +#the significance level of popular items over un-popular ones +recommender.eals.ratio=0.4 + +#confidence weight coefficient, alpha in original paper +recommender.wrmf.weight.coefficient=1.0 + diff --git a/src/test/resources/model/collaborative/ranking/fismauc-test.properties b/src/test/resources/model/collaborative/ranking/fismauc-test.properties new file mode 100644 index 0000000..8358bb0 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/fismauc-test.properties @@ -0,0 +1,14 @@ +recommender.iteration.learnrate=0.00001 +recommender.iterator.maximum=5 +recommender.recommender.isranking=true + +recommender.fismauc.rho=0.5 +recommender.fismauc.alpha=0.9 +recommender.fismauc.gamma=0.1 +recommender.factor.number=10 +recommender.user.regularization=0.001 +recommender.item.regularization=0.001 +recommender.bias.regularization=0.001 + +recommender.init.mean=0.0 +recommender.init.std=0.01 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/fismrmse-test.properties b/src/test/resources/model/collaborative/ranking/fismrmse-test.properties new file mode 100644 index 0000000..0ceee4c --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/fismrmse-test.properties @@ 
-0,0 +1,13 @@ +recommender.iteration.learnrate=0.01 +recommender.iterator.maximum=14 +recommender.recommender.isranking=true + +recommender.fismrmse.rho=1 +recommender.fismrmse.alpha=0.8 +recommender.factor.number=10 +recommender.user.regularization=0.001 +recommender.item.regularization=0.001 +recommender.bias.regularization=0.001 + +recommender.init.mean=0.0 +recommender.init.std=0.01 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/gbpr-test.properties b/src/test/resources/model/collaborative/ranking/gbpr-test.properties new file mode 100644 index 0000000..a1e795f --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/gbpr-test.properties @@ -0,0 +1,11 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=100 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.gpbr.rho=1.5 +recommender.gpbr.gsize=2 + diff --git a/src/test/resources/model/collaborative/ranking/hmm-test.properties b/src/test/resources/model/collaborative/ranking/hmm-test.properties new file mode 100644 index 0000000..9c6f349 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/hmm-test.properties @@ -0,0 +1,6 @@ +recommender.iterator.maximum=10 + +recommender.hmm.state.number=20 +recommender.probability.regularization=40 +recommender.state.regularization=40 +recommender.view.regularization=80000 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/itembigram-test.properties b/src/test/resources/model/collaborative/ranking/itembigram-test.properties new file mode 100644 index 0000000..7bff56a --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/itembigram-test.properties @@ -0,0 +1,7 @@ +recommender.iterator.maximum=100 +recommender.topic.number=10 +recommender.user.dirichlet.prior=0.01 +recommender.topic.dirichlet.prior=0.01 +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 + diff --git a/src/test/resources/model/collaborative/ranking/lambdafmd-test.properties b/src/test/resources/model/collaborative/ranking/lambdafmd-test.properties new file mode 100644 index 0000000..8062445 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/lambdafmd-test.properties @@ -0,0 +1,8 @@ +model=Dynamic +recommender.item.distribution.parameter=1 +recommender.iterator.learnrate=0.0001 +recommender.iterator.maximum=50 +recommender.factor.number=10 +recommender.fm.regw0=0.001 +recommender.fm.regW=0.001 +recommender.fm.regF=0.001 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/lambdafms-test.properties b/src/test/resources/model/collaborative/ranking/lambdafms-test.properties new file mode 100644 index 0000000..45b4daf --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/lambdafms-test.properties @@ -0,0 +1,8 @@ +model=Static +recommender.item.distribution.parameter=1 +recommender.iterator.learnrate=0.0001 +recommender.iterator.maximum=30 +recommender.factor.number=10 +recommender.fm.regw0=0.001 +recommender.fm.regW=0.001 +recommender.fm.regF=0.001 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/lambdafmw-test.properties b/src/test/resources/model/collaborative/ranking/lambdafmw-test.properties new file mode 100644 index 0000000..17c3084 --- /dev/null +++ 
b/src/test/resources/model/collaborative/ranking/lambdafmw-test.properties
@@ -0,0 +1,8 @@
+model=Weight
+recommender.iterator.learnrate=0.0001
+recommender.iterator.maximum=30
+recommender.factor.number=10
+recommender.fm.regw0=0.001
+recommender.fm.regW=0.001
+recommender.fm.regF=0.001
+epsilon=10
\ No newline at end of file
diff --git a/src/test/resources/model/collaborative/ranking/lda-test.properties b/src/test/resources/model/collaborative/ranking/lda-test.properties
new file mode 100644
index 0000000..16055ca
--- /dev/null
+++ b/src/test/resources/model/collaborative/ranking/lda-test.properties
@@ -0,0 +1,9 @@
+recommender.iterator.maximum=100
+recommender.topic.number=10
+recommender.user.dirichlet.prior=0.01
+recommender.topic.dirichlet.prior=0.01
+recommender.pgm.burnin=10
+recommender.pgm.samplelag=10
+# (0.0 may be a better choice than -1.0)
+data.convert.binarize.threshold=0.0
+
diff --git a/src/test/resources/model/collaborative/ranking/listrankmf-test.properties b/src/test/resources/model/collaborative/ranking/listrankmf-test.properties
new file mode 100644
index 0000000..cab0b1a
--- /dev/null
+++ b/src/test/resources/model/collaborative/ranking/listrankmf-test.properties
@@ -0,0 +1,8 @@
+recommender.iterator.learnrate=1.0
+recommender.iterator.learnrate.maximum=100
+recommender.iterator.maximum=100
+recommender.user.regularization=0.06
+recommender.item.regularization=0.06
+recommender.factor.number=5
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
\ No newline at end of file
diff --git a/src/test/resources/model/collaborative/ranking/listwisemf-test.properties b/src/test/resources/model/collaborative/ranking/listwisemf-test.properties
new file mode 100644
index 0000000..fbd2d73
--- /dev/null
+++ b/src/test/resources/model/collaborative/ranking/listwisemf-test.properties
@@ -0,0 +1,8 @@
+recommender.iterator.learnrate=0.01
+recommender.iterator.learnrate.maximum=0.01
+recommender.iterator.maximum=100
+recommender.user.regularization=0.00
+recommender.item.regularization=0.00
+recommender.factor.number=10
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
diff --git a/src/test/resources/model/collaborative/ranking/plsa-test.properties b/src/test/resources/model/collaborative/ranking/plsa-test.properties
new file mode 100644
index 0000000..1c3ada4
--- /dev/null
+++ b/src/test/resources/model/collaborative/ranking/plsa-test.properties
@@ -0,0 +1,5 @@
+recommender.iteration.learnrate=0.01
+recommender.iterator.maximum=100
+recommender.topic.number=10
+# (0.0 may be a better choice than -1.0)
+data.convert.binarize.threshold=0.0
\ No newline at end of file
diff --git a/src/test/resources/model/collaborative/ranking/rankals-test.properties b/src/test/resources/model/collaborative/ranking/rankals-test.properties
new file mode 100644
index 0000000..66da348
--- /dev/null
+++ b/src/test/resources/model/collaborative/ranking/rankals-test.properties
@@ -0,0 +1,10 @@
+recommender.iterator.learnrate=0.01
+recommender.iterator.learnrate.maximum=0.01
+recommender.iterator.maximum=4
+recommender.user.regularization=0.01
+recommender.item.regularization=0.01
+recommender.factor.number=10
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
+
+recommender.rankals.support.weight=true
diff --git a/src/test/resources/model/collaborative/ranking/rankcd-test.properties b/src/test/resources/model/collaborative/ranking/rankcd-test.properties
new file mode 100644
index 0000000..567d355
--- /dev/null
+++ 
b/src/test/resources/model/collaborative/ranking/rankcd-test.properties @@ -0,0 +1,14 @@ +recommender.iterator.learnrate=0.001 +recommender.iterator.learnrate.maximum=0.001 +recommender.factor.number=20 +recommender.iterator.maximum=5 + +recommender.init.mean=0.0 +recommender.init.std=0.1 + +recommender.user.regularization=0.1 +recommender.item.regularization=0.1 + +recommender.rankcd.alpha=5 + +data.convert.binarize.threshold=0.0 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/ranksgd-test.properties b/src/test/resources/model/collaborative/ranking/ranksgd-test.properties new file mode 100644 index 0000000..27354ed --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/ranksgd-test.properties @@ -0,0 +1,8 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=30 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/ranking/rankvfcd-test.properties b/src/test/resources/model/collaborative/ranking/rankvfcd-test.properties new file mode 100644 index 0000000..70ff804 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/rankvfcd-test.properties @@ -0,0 +1,17 @@ +recommender.iterator.learnrate=0.001 +recommender.iterator.learnrate.maximum=0.001 +recommender.factor.number=20 +recommender.iterator.maximum=5 + +recommender.init.mean=0.0 +recommender.init.std=0.01 + +recommender.user.regularization=0.1 +recommender.item.regularization=0.1 + +recommender.rankvfcd.alpha=5 +recommender.rankvfcd.beta=10 +recommender.rankvfcd.gamma=50 +recommender.rankvfcd.lamutaE=50 + +data.convert.binarize.threshold=0.0 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/slim-test.properties b/src/test/resources/model/collaborative/ranking/slim-test.properties new file mode 100644 index 0000000..7ba3a88 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/slim-test.properties @@ -0,0 +1,8 @@ +recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.CosineSimilarity +recommender.iterator.maximum=40 +recommender.neighbors.knn.number=50 +recommender.recommender.earlystop=true + +recommender.slim.regularization.l1=1 +recommender.slim.regularization.l2=5 + diff --git a/src/test/resources/model/collaborative/ranking/vbpr-test.properties b/src/test/resources/model/collaborative/ranking/vbpr-test.properties new file mode 100644 index 0000000..64d29f2 --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/vbpr-test.properties @@ -0,0 +1,15 @@ +recommender.iterator.learnrate=0.001 +recommender.iterator.learnrate.maximum=0.001 +recommender.factor.number=20 +recommender.iterator.maximum=5 + +recommender.init.mean=0.0 +recommender.init.std=0.01 + +recommender.user.regularization=0.1 +recommender.item.regularization=0.1 +recommender.bias.regularization=0.1 + +recommender.vbpr.alpha=10 + +data.convert.binarize.threshold=0.0 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/warpmf-test.properties b/src/test/resources/model/collaborative/ranking/warpmf-test.properties new file mode 100644 index 0000000..db2506e --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/warpmf-test.properties @@ -0,0 +1,7 @@ +recommender.iterator.maximum=20 +recommender.user.regularization=0.01 
+recommender.item.regularization=0.01 +recommender.factor.number=10 + +#confidence weight coefficient, alpha in original paper +epsilon=0.5 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/ranking/wbpr-test.properties b/src/test/resources/model/collaborative/ranking/wbpr-test.properties new file mode 100644 index 0000000..e6e010e --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/wbpr-test.properties @@ -0,0 +1,9 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=20 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.bias.regularization=0.01 +recommender.factor.number=128 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/ranking/wrmf-test.properties b/src/test/resources/model/collaborative/ranking/wrmf-test.properties new file mode 100644 index 0000000..7d9886d --- /dev/null +++ b/src/test/resources/model/collaborative/ranking/wrmf-test.properties @@ -0,0 +1,7 @@ +recommender.iterator.maximum=20 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 + +#confidence weight coefficient, alpha in original paper +recommender.wrmf.weight.coefficient=4.0 diff --git a/src/test/resources/model/collaborative/rating/aspectmodelrating-test.properties b/src/test/resources/model/collaborative/rating/aspectmodelrating-test.properties new file mode 100644 index 0000000..c300b90 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/aspectmodelrating-test.properties @@ -0,0 +1,2 @@ +recommender.iteration.learnrate=0.01 +recommender.iterator.maximum=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/rating/asvdpp-test.properties b/src/test/resources/model/collaborative/rating/asvdpp-test.properties new file mode 100644 index 0000000..c63c61b --- /dev/null +++ b/src/test/resources/model/collaborative/rating/asvdpp-test.properties @@ -0,0 +1,2 @@ +recommender.iteration.learnrate=0.01 +recommender.iterator.maximum=20 diff --git a/src/test/resources/model/collaborative/rating/autorec-test.properties b/src/test/resources/model/collaborative/rating/autorec-test.properties new file mode 100644 index 0000000..98ae706 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/autorec-test.properties @@ -0,0 +1,10 @@ +recommender.iterator.learnrate=0.025 +recommender.iterator.learnrate.maximum=0.015 +recommender.iterator.momentum=0.9 +recommender.iterator.maximum=200 +recommender.weight.regularization=0.001 +recommender.hidden.dimension=200 +recommender.hidden.activation=sigmoid +recommender.output.activation=identity +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/rating/biasedmf-test.properties b/src/test/resources/model/collaborative/rating/biasedmf-test.properties new file mode 100644 index 0000000..21e3e0e --- /dev/null +++ b/src/test/resources/model/collaborative/rating/biasedmf-test.properties @@ -0,0 +1,9 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=10 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.bias.regularization=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git 
a/src/test/resources/model/collaborative/rating/bpmf-test.properties b/src/test/resources/model/collaborative/rating/bpmf-test.properties new file mode 100644 index 0000000..0f8bbe7 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/bpmf-test.properties @@ -0,0 +1,16 @@ +recommender.iterator.maximum=150 +recommender.factor.number=20 + +recommender.recommender.user.mu=0.0 +recommender.recommender.item.mu=0.0 + +recommender.recommender.user.beta=1.0 +recommender.recommender.item.beta=1.0 + +recommender.recommender.user.wishart.scale=1.0 +recommender.recommender.item.wishart.scale=1.0 + +recommender.recommender.rating.sigma=2.0 + +# recommender.learnrate.bolddriver=false +# recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/rating/bpoissmf-test.properties b/src/test/resources/model/collaborative/rating/bpoissmf-test.properties new file mode 100644 index 0000000..6a02ca7 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/bpoissmf-test.properties @@ -0,0 +1,8 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=100 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/rating/ccd-test.properties b/src/test/resources/model/collaborative/rating/ccd-test.properties new file mode 100644 index 0000000..7eb0054 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/ccd-test.properties @@ -0,0 +1,10 @@ +recommender.init.mean=0.0 +recommender.init.std=0.1 +recommender.iterator.maximum=500 +recommender.user.regularization=1.0 +recommender.item.regularization=1.0 + +recommender.factor.number=20 + +recommender.recommender.rating.minimum=1.0 +recommender.recommender.rating.maximum=5.0 diff --git a/src/test/resources/model/collaborative/rating/ffm-test.properties b/src/test/resources/model/collaborative/rating/ffm-test.properties new file mode 100644 index 0000000..dc7fd4b --- /dev/null +++ b/src/test/resources/model/collaborative/rating/ffm-test.properties @@ -0,0 +1,3 @@ +recommender.iterator.learnrate=0.001 +recommender.iterator.maximum=100 +recommender.factor.number=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/rating/fmals-test.properties b/src/test/resources/model/collaborative/rating/fmals-test.properties new file mode 100644 index 0000000..45a62d3 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/fmals-test.properties @@ -0,0 +1,3 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.maximum=100 +recommender.factor.number=10 diff --git a/src/test/resources/model/collaborative/rating/fmsgd-test.properties b/src/test/resources/model/collaborative/rating/fmsgd-test.properties new file mode 100644 index 0000000..dc7fd4b --- /dev/null +++ b/src/test/resources/model/collaborative/rating/fmsgd-test.properties @@ -0,0 +1,3 @@ +recommender.iterator.learnrate=0.001 +recommender.iterator.maximum=100 +recommender.factor.number=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/rating/gplsa-test.properties b/src/test/resources/model/collaborative/rating/gplsa-test.properties new file mode 100644 index 0000000..e3432b8 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/gplsa-test.properties @@ -0,0 +1,5 @@ +recommender.iteration.learnrate=0.01 +recommender.iterator.maximum=100 
+recommender.recommender.smoothWeight=2 +recommender.recommender.isranking=false +recommender.topic.number = 10 diff --git a/src/test/resources/model/collaborative/rating/irrg-test.properties b/src/test/resources/model/collaborative/rating/irrg-test.properties new file mode 100644 index 0000000..8e1ed9e --- /dev/null +++ b/src/test/resources/model/collaborative/rating/irrg-test.properties @@ -0,0 +1,11 @@ +recommender.iterator.learnrate=0.001 +recommender.iterator.learnrate.maximum=10 + +recommender.iterator.maximum=200 +recommender.user.regularization=0.001 +recommender.item.regularization=0.001 +recommender.alpha=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=true +recommender.learnrate.decay=1.0 + diff --git a/src/test/resources/model/collaborative/rating/ldcc-test.properties b/src/test/resources/model/collaborative/rating/ldcc-test.properties new file mode 100644 index 0000000..4bff145 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/ldcc-test.properties @@ -0,0 +1,4 @@ +recommender.iteration.learnrate=0.01 +recommender.iterator.maximum=100 +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/rating/llorma-test.properties b/src/test/resources/model/collaborative/rating/llorma-test.properties new file mode 100644 index 0000000..3c05ad2 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/llorma-test.properties @@ -0,0 +1,18 @@ +recommender.global.factors.num=10 +recommender.global.iteration.learnrate=0.0005 +recommender.global.user.regularization=0.1 +recommender.global.item.regularization=0.1 +recommender.global.iteration.maximum=200 +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=200 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=6 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.model.num=55 +recommender.thread.count=1 + +recommender.init.mean=0.0 +recommender.init.std=0.01 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/rating/mfals-test.properties b/src/test/resources/model/collaborative/rating/mfals-test.properties new file mode 100644 index 0000000..5419fa1 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/mfals-test.properties @@ -0,0 +1,4 @@ +recommender.iterator.maximum=100 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 diff --git a/src/test/resources/model/collaborative/rating/nmf-test.properties b/src/test/resources/model/collaborative/rating/nmf-test.properties new file mode 100644 index 0000000..cf48b16 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/nmf-test.properties @@ -0,0 +1,4 @@ +recommender.iterator.maximum=10 +recommender.factor.number=100 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/rating/pmf-test.properties b/src/test/resources/model/collaborative/rating/pmf-test.properties new file mode 100644 index 0000000..6e47ae8 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/pmf-test.properties @@ -0,0 +1,8 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=70 +recommender.user.regularization=0.08 +recommender.item.regularization=0.08 +recommender.factor.number=6 +recommender.learnrate.bolddriver=false 
+recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/rating/rbm-test.properties b/src/test/resources/model/collaborative/rating/rbm-test.properties new file mode 100644 index 0000000..a65a44a --- /dev/null +++ b/src/test/resources/model/collaborative/rating/rbm-test.properties @@ -0,0 +1,10 @@ +recommender.iterator.maximum=20 +recommender.factor.number=500 +recommender.epsilonw=0.01 +recommender.epsilonvb=0.01 +recommender.epsilonhb=0.01 +recommender.tstep=1 +recommender.momentum=0.1 +recommender.lamtaw=0.01 +recommender.lamtab=0.0 +recommender.predictiontype=mean \ No newline at end of file diff --git a/src/test/resources/model/collaborative/rating/rfrec-test.properties b/src/test/resources/model/collaborative/rating/rfrec-test.properties new file mode 100644 index 0000000..f222756 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/rfrec-test.properties @@ -0,0 +1,8 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=10 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/rating/svdpp-test.properties b/src/test/resources/model/collaborative/rating/svdpp-test.properties new file mode 100644 index 0000000..ad95684 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/svdpp-test.properties @@ -0,0 +1,9 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=13 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.impItem.regularization=0.001 +recommender.factor.number=10 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 diff --git a/src/test/resources/model/collaborative/rating/urp-test.properties b/src/test/resources/model/collaborative/rating/urp-test.properties new file mode 100644 index 0000000..4bff145 --- /dev/null +++ b/src/test/resources/model/collaborative/rating/urp-test.properties @@ -0,0 +1,4 @@ +recommender.iteration.learnrate=0.01 +recommender.iterator.maximum=100 +recommender.pgm.burnin=10 +recommender.pgm.samplelag=10 \ No newline at end of file diff --git a/src/test/resources/model/collaborative/userknnranking-test.properties b/src/test/resources/model/collaborative/userknnranking-test.properties new file mode 100644 index 0000000..227c3ae --- /dev/null +++ b/src/test/resources/model/collaborative/userknnranking-test.properties @@ -0,0 +1,2 @@ +recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.JaccardIndexSimilarity +recommender.neighbors.knn.number=50 diff --git a/src/test/resources/model/collaborative/userknnrating-test.properties b/src/test/resources/model/collaborative/userknnrating-test.properties new file mode 100644 index 0000000..1b817ee --- /dev/null +++ b/src/test/resources/model/collaborative/userknnrating-test.properties @@ -0,0 +1,3 @@ +recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.PCCSimilarity +recommender.neighbors.knn.number=50 +recommender.filter.class=generic diff --git a/src/test/resources/model/content/efmranking-test.properties b/src/test/resources/model/content/efmranking-test.properties new file mode 100644 index 0000000..3f95e1e --- /dev/null +++ b/src/test/resources/model/content/efmranking-test.properties @@ -0,0 +1,14 @@ +recommender.iterator.maximum=200 
+recommender.factor.number=10
+recommender.factor.explicit=5
+recommender.regularization.lambdax=1
+recommender.regularization.lambday=1
+recommender.regularization.lambdau=0.01
+recommender.regularization.lambdah=0.01
+recommender.regularization.lambdav=0.01
+
+recommender.explain.flag=true
+recommender.explain.userids=480 8517 550
+recommender.explain.numfeature=5
+
+
diff --git a/src/test/resources/model/content/efmrating-test.properties b/src/test/resources/model/content/efmrating-test.properties
new file mode 100644
index 0000000..d2e6988
--- /dev/null
+++ b/src/test/resources/model/content/efmrating-test.properties
@@ -0,0 +1,17 @@
+recommender.iterator.maximum=50
+recommender.factor.number=10
+recommender.factor.explicit=5
+recommender.regularization.lambdax=1
+recommender.regularization.lambday=1
+recommender.regularization.lambdau=0.01
+recommender.regularization.lambdah=0.01
+recommender.regularization.lambdav=0.01
+
+recommender.explain.flag=true
+recommender.explain.userids=480 8517 550
+recommender.explain.numfeature=5
+
+recommender.recommender.rating.minimum=1.0
+recommender.recommender.rating.maximum=5.0
+
+
diff --git a/src/test/resources/model/content/hft-test.properties b/src/test/resources/model/content/hft-test.properties
new file mode 100644
index 0000000..a6b8d61
--- /dev/null
+++ b/src/test/resources/model/content/hft-test.properties
@@ -0,0 +1,18 @@
+# The training approach is SGD instead of L-BFGS, so it can be slow if the dataset
+# is big. If you want a quick test, try the path: test/hfttest/musical_instruments.arff
+# The path of the full dataset is: test/hfttest/musical_instruments_full.arff
+
+recommender.iterator.learnrate=0.01
+recommender.iterator.learnrate.maximum=0.01
+recommender.iterator.maximum=2
+recommender.user.regularization=0.01
+recommender.item.regularization=0.01
+recommender.factor.number=10
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
+recommender.recommender.lambda.user=0.05
+recommender.recommender.lambda.item=0.05
+recommender.bias.regularization=0.01
+
+recommender.recommender.rating.minimum=1.0
+recommender.recommender.rating.maximum=5.0
\ No newline at end of file
diff --git a/src/test/resources/model/content/tfidf-test.properties b/src/test/resources/model/content/tfidf-test.properties
new file mode 100644
index 0000000..14ed3d3
--- /dev/null
+++ b/src/test/resources/model/content/tfidf-test.properties
@@ -0,0 +1 @@
+recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.CosineSimilarity
diff --git a/src/test/resources/model/content/topicmfat-test.properties b/src/test/resources/model/content/topicmfat-test.properties
new file mode 100644
index 0000000..85d33ee
--- /dev/null
+++ b/src/test/resources/model/content/topicmfat-test.properties
@@ -0,0 +1,12 @@
+recommender.regularization.lambda=0.001
+recommender.regularization.lambdaU=0.001
+recommender.regularization.lambdaV=0.001
+recommender.regularization.lambdaB=0.001
+recommender.topic.number=10
+recommender.iterator.learnrate=0.01
+recommender.iterator.maximum=10
+recommender.init.mean=0.0
+recommender.init.std=0.01
+
+recommender.recommender.rating.minimum=1.0
+recommender.recommender.rating.maximum=5.0
\ No newline at end of file
diff --git a/src/test/resources/model/content/topicmfmt-test.properties b/src/test/resources/model/content/topicmfmt-test.properties
new file mode 100644
index 0000000..85d33ee
--- /dev/null
+++ b/src/test/resources/model/content/topicmfmt-test.properties
@@ -0,0 +1,12 @@
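+# TopicMF (MT variant) couples matrix factorization with topic modeling of review text; the
+# lambda* keys below weight the individual regularization terms (descriptive note based on
+# the TopicMF paper, not on anything shown in this diff).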
+recommender.regularization.lambda=0.001 +recommender.regularization.lambdaU=0.001 +recommender.regularization.lambdaV=0.001 +recommender.regularization.lambdaB=0.001 +recommender.topic.number=10 +recommender.iterator.learnrate=0.01 +recommender.iterator.maximum=10 +recommender.init.mean=0.0 +recommender.init.std=0.01 + +recommender.recommender.rating.minimum=1.0 +recommender.recommender.rating.maximum=5.0 \ No newline at end of file diff --git a/src/test/resources/model/context/ranking/rankgeofm-test.properties b/src/test/resources/model/context/ranking/rankgeofm-test.properties new file mode 100644 index 0000000..2d525e1 --- /dev/null +++ b/src/test/resources/model/context/ranking/rankgeofm-test.properties @@ -0,0 +1,10 @@ +recommender.factor.number=100 +recommender.iterator.learnrate=0.001 +recommender.iterator.learnrate.maximum=0.001 +recommender.iterator.maximum=200 +recommender.regularization.radius=1.0 +recommender.regularization.balance=0.2 +recommender.ranking.margin=0.3 +recommender.item.nearest.neighbour.number=300 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 \ No newline at end of file diff --git a/src/test/resources/model/context/ranking/sbpr-test.properties b/src/test/resources/model/context/ranking/sbpr-test.properties new file mode 100644 index 0000000..e7c0b50 --- /dev/null +++ b/src/test/resources/model/context/ranking/sbpr-test.properties @@ -0,0 +1,12 @@ +recommender.iterator.learnrate=0.01 +recommender.iterator.learnrate.maximum=0.01 +recommender.iterator.maximum=200 +recommender.user.regularization=0.01 +recommender.item.regularization=0.01 +recommender.social.regularization=0.01 +recommender.bias.regularization=0.01 +recommender.factor.number=128 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.recommender.earlystop=false +recommender.recommender.verbose=true diff --git a/src/test/resources/model/context/rating/rste-test.properties b/src/test/resources/model/context/rating/rste-test.properties new file mode 100644 index 0000000..c8ad79b --- /dev/null +++ b/src/test/resources/model/context/rating/rste-test.properties @@ -0,0 +1,11 @@ +recommender.iterator.learnrate=0.02 +recommender.iterator.learnrate.maximum=0.02 +recommender.iterator.maximum=100 +recommender.user.regularization=0.001 +recommender.item.regularization=0.001 +recommender.user.social.ratio=1.0 +recommender.factor.number=5 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.recommender.earlystop=false +recommender.recommender.verbose=true diff --git a/src/test/resources/model/context/rating/socialmf-test.properties b/src/test/resources/model/context/rating/socialmf-test.properties new file mode 100644 index 0000000..593f35f --- /dev/null +++ b/src/test/resources/model/context/rating/socialmf-test.properties @@ -0,0 +1,11 @@ +recommender.iterator.learnrate=0.02 +recommender.iterator.learnrate.maximum=-1 +recommender.iterator.maximum=100 +recommender.user.regularization=0.001 +recommender.item.regularization=0.001 +recommender.social.regularization=1.0 +recommender.factor.number=5 +recommender.learnrate.bolddriver=false +recommender.learnrate.decay=1.0 +recommender.recommender.earlystop=false +recommender.recommender.verbose=true diff --git a/src/test/resources/model/context/rating/sorec-test.properties b/src/test/resources/model/context/rating/sorec-test.properties new file mode 100644 index 0000000..885b397 --- /dev/null +++ b/src/test/resources/model/context/rating/sorec-test.properties @@ -0,0 +1,12 @@ 
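+# SoRec co-factorizes the rating matrix with the user social network; the rate.social and
+# user.social keys below weight the two regularization terms (descriptive note based on the
+# SoRec paper, not on anything shown in this diff).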
+recommender.iterator.learnrate=0.02
+recommender.iterator.learnrate.maximum=-1
+recommender.iterator.maximum=100
+recommender.user.regularization=0.001
+recommender.item.regularization=0.001
+recommender.rate.social.regularization=0.01
+recommender.user.social.regularization=0.01
+recommender.factor.number=5
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
+recommender.recommender.earlystop=false
+recommender.recommender.verbose=true
diff --git a/src/test/resources/model/context/rating/soreg-test.properties b/src/test/resources/model/context/rating/soreg-test.properties
new file mode 100644
index 0000000..696e360
--- /dev/null
+++ b/src/test/resources/model/context/rating/soreg-test.properties
@@ -0,0 +1,12 @@
+recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.CosineSimilarity
+recommender.iterator.learnrate=0.001
+recommender.iterator.learnrate.maximum=0.01
+recommender.iterator.maximum=100
+recommender.user.regularization=0.001
+recommender.item.regularization=0.001
+recommender.social.regularization=0.1
+recommender.factor.number=10
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
+recommender.recommender.earlystop=false
+recommender.recommender.verbose=true
diff --git a/src/test/resources/model/context/rating/timesvd-test.properties b/src/test/resources/model/context/rating/timesvd-test.properties
new file mode 100644
index 0000000..59d3a6e
--- /dev/null
+++ b/src/test/resources/model/context/rating/timesvd-test.properties
@@ -0,0 +1,6 @@
+recommender.iterator.learnrate=0.01
+recommender.iterator.learnrate.maximum=0.01
+recommender.iterator.maximum=100
+recommender.user.regularization=0.01
+recommender.item.regularization=0.01
+recommender.learnrate.decay=1.0
diff --git a/src/test/resources/model/context/rating/trustmf-test.properties b/src/test/resources/model/context/rating/trustmf-test.properties
new file mode 100644
index 0000000..baba00c
--- /dev/null
+++ b/src/test/resources/model/context/rating/trustmf-test.properties
@@ -0,0 +1,12 @@
+recommender.iterator.learnrate=0.05
+recommender.iterator.learnrate.maximum=0.01
+recommender.iterator.maximum=200
+recommender.user.regularization=0.001
+recommender.item.regularization=0.001
+recommender.social.regularization=1.0
+recommender.factor.number=10
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
+recommender.recommender.earlystop=false
+recommender.recommender.verbose=true
+recommender.social.model=T
diff --git a/src/test/resources/model/context/rating/trustsvd-test.properties b/src/test/resources/model/context/rating/trustsvd-test.properties
new file mode 100644
index 0000000..8fd562b
--- /dev/null
+++ b/src/test/resources/model/context/rating/trustsvd-test.properties
@@ -0,0 +1,12 @@
+recommender.iterator.learnrate=0.005
+recommender.iterator.learnrate.maximum=-1
+recommender.iterator.maximum=100
+recommender.user.regularization=1.2
+recommender.item.regularization=1.2
+recommender.social.regularization=0.9
+recommender.bias.regularization=1.2
+recommender.factor.number=10
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
+recommender.recommender.earlystop=false
+recommender.recommender.verbose=true
diff --git a/src/test/resources/model/extend/associationrule-test.properties b/src/test/resources/model/extend/associationrule-test.properties
new file mode 100644
index 0000000..e69de29
diff --git a/src/test/resources/model/extend/external-test.properties b/src/test/resources/model/extend/external-test.properties
new file mode 100644
index 0000000..e69de29
diff --git a/src/test/resources/model/extend/personalitydiagnosis-test.properties b/src/test/resources/model/extend/personalitydiagnosis-test.properties
new file mode 100644
index 0000000..126e24c
--- /dev/null
+++ b/src/test/resources/model/extend/personalitydiagnosis-test.properties
@@ -0,0 +1 @@
+recommender.PersonalityDiagnosis.sigma=2.0
diff --git a/src/test/resources/model/extend/prankd-test.properties b/src/test/resources/model/extend/prankd-test.properties
new file mode 100644
index 0000000..d890d4b
--- /dev/null
+++ b/src/test/resources/model/extend/prankd-test.properties
@@ -0,0 +1,4 @@
+recommender.correlation.class=com.jstarcraft.ai.math.algorithm.correlation.similarity.CosineSimilarity
+recommender.learnrate.bolddriver=false
+recommender.learnrate.decay=1.0
+recommender.sim.filter=4.0
diff --git a/src/test/resources/model/extend/slopeone-test.properties b/src/test/resources/model/extend/slopeone-test.properties
new file mode 100644
index 0000000..e69de29
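
Note (not part of the patch): every fixture above is a plain java.util.Properties file, so a quick way to sanity-check one is to load it off the test classpath and parse a few hyper-parameters. A minimal sketch follows, assuming src/test/resources is the classpath root (the usual Maven layout); the PropertiesSmokeTest class name and the choice of fixture are illustrative, not part of this change.

    import java.io.InputStream;
    import java.util.Properties;

    // Hypothetical smoke test for the new fixtures; name and fixture are assumptions.
    public class PropertiesSmokeTest {

        public static void main(String[] arguments) throws Exception {
            Properties properties = new Properties();
            // Under the standard Maven layout, src/test/resources maps to the classpath root.
            try (InputStream stream = PropertiesSmokeTest.class
                    .getResourceAsStream("/model/context/rating/trustsvd-test.properties")) {
                properties.load(stream);
            }
            // Values are stored as strings; numeric hyper-parameters are parsed by the consumer.
            float learnRate = Float.parseFloat(properties.getProperty("recommender.iterator.learnrate"));
            int factorNumber = Integer.parseInt(properties.getProperty("recommender.factor.number"));
            System.out.println("learn rate = " + learnRate + ", factors = " + factorNumber);
        }
    }

Multi-valued keys such as recommender.explain.userids=480 8517 550 come back as a single string, so a consumer would split on whitespace before converting to identifiers.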