From 15e264af65e08505ec85d8cb0dd53170ac044985 Mon Sep 17 00:00:00 2001 From: Jauhien Piatlicki Date: Fri, 17 Apr 2015 11:31:25 +0200 Subject: [g_sorcery/package_db] add category common data setter and getter to DB API --- g_sorcery/package_db.py | 38 ++++++++++++++++++++++++++++++++++++++ tests/test_PackageDB.py | 17 +++++++++++++---- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/g_sorcery/package_db.py b/g_sorcery/package_db.py index ec2d45f..5374ae5 100644 --- a/g_sorcery/package_db.py +++ b/g_sorcery/package_db.py @@ -105,6 +105,7 @@ class PackageDB(object): self.pkg_name, self.pkg_data = next(self.pkgs_iter) self.vers_iter = iter(self.pkg_data.items()) + ebuild_data = dict(ebuild_data) ebuild_data.update(self.cat_data['common_data']) return (Package(self.cat_name, self.pkg_name, ver), ebuild_data) @@ -129,6 +130,7 @@ class PackageDB(object): self.pkg_name, self.pkg_data = next(self.pkgs_iter) self.vers_iter = iter(self.pkg_data.items()) + ebuild_data = dict(ebuild_data) ebuild_data.update(self.cat_data['common_data']) return (Package(self.cat_name, self.pkg_name, ver), ebuild_data) @@ -275,6 +277,42 @@ class PackageDB(object): self.categories[category] = description + def set_common_data(self, category, common_data): + """ + Set common data for a category. + + Args: + category: Category name. + common_data: Category common data. + """ + if not category in self.categories: + raise InvalidKeyError('Non-existent category: ' + category) + + if not category in self.database: + self.database[category] = {'common_data': common_data, 'packages': {}} + else: + self.database[category]['common_data'] = common_data + + + def get_common_data(self, category): + """ + Get common data for a category. + + Args: + category: Category name. + + Returns: + Dictionary with category common data. + """ + if not category in self.categories: + raise InvalidKeyError('Non-existent category: ' + category) + + if not category in self.database: + return {} + else: + return self.database[category]['common_data'] + + def add_package(self, package, ebuild_data=None): """ Add a package. diff --git a/tests/test_PackageDB.py b/tests/test_PackageDB.py index f73f006..2a67385 100644 --- a/tests/test_PackageDB.py +++ b/tests/test_PackageDB.py @@ -38,11 +38,15 @@ class TestPackageDB(BaseTest): orig_db = PackageDB(orig_path) orig_db.add_category("app-test1") orig_db.add_category("app-test2") - ebuild_data = {"test1": "test1", "test2": "test2"} + ebuild_data = {"test1": "tst1", "test2": "tst2"} + common_data = {"common1": "cmn1", "common2": "cmn2"} packages = [Package("app-test1", "test", "1"), Package("app-test1", "test", "2"), Package("app-test1", "test1", "1"), Package("app-test2", "test2", "1")] for package in packages: orig_db.add_package(package, ebuild_data) + orig_db.set_common_data("app-test1", common_data) + full_data = dict(ebuild_data) + full_data.update(common_data) orig_db.write() os.system("cd " + orig_tempdir.name + " && tar cvzf good.tar.gz db") @@ -61,6 +65,8 @@ class TestPackageDB(BaseTest): srv.join() test_db.read() self.assertEqual(orig_db.database, test_db.database) + self.assertEqual(orig_db.get_common_data("app-test1"), test_db.get_common_data("app-test1")) + self.assertEqual(orig_db.get_common_data("app-test2"), test_db.get_common_data("app-test2")) self.assertEqual(set(test_db.list_categories()), set(["app-test1", "app-test2"])) self.assertTrue(test_db.in_category("app-test1", "test")) self.assertFalse(test_db.in_category("app-test2", "test")) @@ -71,7 +77,7 @@ class TestPackageDB(BaseTest): self.assertRaises(InvalidKeyError, test_db.list_package_versions, "app-test1", "invalid") self.assertEqual(set(test_db.list_package_versions("app-test1", "test")), set(['1', '2'])) self.assertEqual(set(test_db.list_all_packages()), set(packages)) - self.assertEqual(test_db.get_package_description(packages[0]), ebuild_data) + self.assertEqual(test_db.get_package_description(packages[0]), full_data) self.assertRaises(KeyError, test_db.get_package_description, Package("invalid", "invalid", "1")) self.assertEqual(test_db.get_max_version("app-test1", "test"), "2") self.assertEqual(test_db.get_max_version("app-test1", "test1"), "1") @@ -79,10 +85,13 @@ class TestPackageDB(BaseTest): pkg_set = set(packages) for package, data in test_db: self.assertTrue(package in pkg_set) - self.assertEqual(data, ebuild_data) + if package.category == "app-test1": + self.assertEqual(data, full_data) + else: + self.assertEqual(data, ebuild_data) pkg_set.remove(package) self.assertTrue(not pkg_set) - + self.assertEqual(orig_db.database, test_db.database) def suite(): suite = unittest.TestSuite() -- cgit v1.2.3-65-gdbad