matrix-synapse / 7b07dc8: New upstream version 1.28.0 (Andrej Shadura)
335 changed files with 7067 additions and 3900 deletions.
1313 platforms: linux/amd64
1414 - docker_build:
1515 tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
16 platforms: linux/amd64,linux/arm/v7,linux/arm64
16 platforms: linux/amd64,linux/arm64
1717
1818 dockerhubuploadlatest:
1919 docker:
2626 # until all of the platforms are built.
2727 - docker_build:
2828 tag: -t matrixdotorg/synapse:latest
29 platforms: linux/amd64,linux/arm/v7,linux/arm64
29 platforms: linux/amd64,linux/arm64
3030
3131 workflows:
3232 build:
0 Synapse 1.28.0 (2021-02-25)
1 ===========================
2
3 Note that this release drops support for ARMv7 in the official Docker images, due to repeated problems building for ARMv7 (and the associated maintenance burden this entails).
4
5 This release also fixes the documentation included in v1.27.0 around the callback URI for SAML2 identity providers. If your server is configured to use single sign-on via a SAML2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
6
7
8 Internal Changes
9 ----------------
10
11 - Revert change in v1.28.0rc1 to remove the deprecated SAML endpoint. ([\#9474](https://github.com/matrix-org/synapse/issues/9474))
12
13
14 Synapse 1.28.0rc1 (2021-02-19)
15 ==============================
16
17 Removal warning
18 ---------------
19
20 The v1 list accounts API is deprecated and will be removed in a future release.
21 This API was undocumented and misleading. It can be replaced by the
22 [v2 list accounts API](https://github.com/matrix-org/synapse/blob/release-v1.28.0/docs/admin_api/user_admin_api.rst#list-accounts),
23 which has been available since Synapse 1.7.0 (2019-12-13).
24
25 Please check if you're using any scripts which use the admin API and replace
26 `GET /_synapse/admin/v1/users/<user_id>` with `GET /_synapse/admin/v2/users`.
27
28
29 Features
30 --------
31
32 - New admin API to get the context of an event: `/_synapse/admin/rooms/{roomId}/context/{eventId}`. ([\#9150](https://github.com/matrix-org/synapse/issues/9150))
33 - Further improvements to the user experience of registration via single sign-on. ([\#9300](https://github.com/matrix-org/synapse/issues/9300), [\#9301](https://github.com/matrix-org/synapse/issues/9301))
34 - Add hook to spam checker modules that allow checking file uploads and remote downloads. ([\#9311](https://github.com/matrix-org/synapse/issues/9311))
35 - Add support for receiving OpenID Connect authentication responses via form `POST`s rather than `GET`s. ([\#9376](https://github.com/matrix-org/synapse/issues/9376))
36 - Add the shadow-banning status to the admin API for user info. ([\#9400](https://github.com/matrix-org/synapse/issues/9400))
37
38
39 Bugfixes
40 --------
41
42 - Fix long-standing bug where sending email notifications would fail for rooms that the server had since left. ([\#9257](https://github.com/matrix-org/synapse/issues/9257))
43 - Fix bug introduced in Synapse 1.27.0rc1 which meant the "session expired" error page during SSO registration was badly formatted. ([\#9296](https://github.com/matrix-org/synapse/issues/9296))
44 - Assert a maximum length for some parameters for spec compliance. ([\#9321](https://github.com/matrix-org/synapse/issues/9321), [\#9393](https://github.com/matrix-org/synapse/issues/9393))
45 - Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.". ([\#9333](https://github.com/matrix-org/synapse/issues/9333))
46 - Fix a bug causing Synapse to impose the wrong type constraints on fields when processing responses from appservices to `/_matrix/app/v1/thirdparty/user/{protocol}`. ([\#9361](https://github.com/matrix-org/synapse/issues/9361))
47 - Fix bug where Synapse would occasionally stop reconnecting to Redis after the connection was lost. ([\#9391](https://github.com/matrix-org/synapse/issues/9391))
48 - Fix a long-standing bug when upgrading a room: "TypeError: '>' not supported between instances of 'NoneType' and 'int'". ([\#9395](https://github.com/matrix-org/synapse/issues/9395))
49 - Reduce the amount of memory used when generating the URL preview of a file that is larger than the `max_spider_size`. ([\#9421](https://github.com/matrix-org/synapse/issues/9421))
50 - Fix a long-standing bug in the deduplication of old presence, resulting in no deduplication. ([\#9425](https://github.com/matrix-org/synapse/issues/9425))
51 - The `ui_auth.session_timeout` config option can now be specified in terms of number of seconds/minutes/etc. Contributed by Rishabh Arya. ([\#9426](https://github.com/matrix-org/synapse/issues/9426))
52 - Fix a bug introduced in v1.27.0: "TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'." related to the user directory. ([\#9428](https://github.com/matrix-org/synapse/issues/9428))
53
54
55 Updates to the Docker image
56 ---------------------------
57
58 - Drop support for ARMv7 in Docker images. ([\#9433](https://github.com/matrix-org/synapse/issues/9433))
59
60
61 Improved Documentation
62 ----------------------
63
64 - Reorganize CHANGELOG.md. ([\#9281](https://github.com/matrix-org/synapse/issues/9281))
65 - Add note to `auto_join_rooms` config option explaining existing rooms must be publicly joinable. ([\#9291](https://github.com/matrix-org/synapse/issues/9291))
66 - Correct name of Synapse's service file in TURN howto. ([\#9308](https://github.com/matrix-org/synapse/issues/9308))
67 - Fix the braces in the `oidc_providers` section of the sample config. ([\#9317](https://github.com/matrix-org/synapse/issues/9317))
68 - Update installation instructions on Fedora. ([\#9322](https://github.com/matrix-org/synapse/issues/9322))
69 - Add HTTP/2 support to the nginx example configuration. Contributed by David Vo. ([\#9390](https://github.com/matrix-org/synapse/issues/9390))
70 - Update docs for using Gitea as OpenID provider. ([\#9404](https://github.com/matrix-org/synapse/issues/9404))
71 - Document that pusher instances are shardable. ([\#9407](https://github.com/matrix-org/synapse/issues/9407))
72 - Fix erroneous documentation from v1.27.0 about updating the SAML2 callback URL. ([\#9434](https://github.com/matrix-org/synapse/issues/9434))
73
74
75 Deprecations and Removals
76 -------------------------
77
78 - Deprecate old admin API `GET /_synapse/admin/v1/users/<user_id>`. ([\#9429](https://github.com/matrix-org/synapse/issues/9429))
79
80
81 Internal Changes
82 ----------------
83
84 - Fix 'object name reserved for internal use' errors with recent versions of SQLite. ([\#9003](https://github.com/matrix-org/synapse/issues/9003))
85 - Add experimental support for running Synapse with PyPy. ([\#9123](https://github.com/matrix-org/synapse/issues/9123))
86 - Deny access to additional IP addresses by default. ([\#9240](https://github.com/matrix-org/synapse/issues/9240))
87 - Update the `Cursor` type hints to better match PEP 249. ([\#9299](https://github.com/matrix-org/synapse/issues/9299))
88 - Add debug logging for SRV lookups. Contributed by @Bubu. ([\#9305](https://github.com/matrix-org/synapse/issues/9305))
89 - Improve logging for OIDC login flow. ([\#9307](https://github.com/matrix-org/synapse/issues/9307))
90 - Share the code for handling required attributes between the CAS and SAML handlers. ([\#9326](https://github.com/matrix-org/synapse/issues/9326))
91 - Clean up the code to load the metadata for OpenID Connect identity providers. ([\#9362](https://github.com/matrix-org/synapse/issues/9362))
92 - Convert tests to use `HomeserverTestCase`. ([\#9377](https://github.com/matrix-org/synapse/issues/9377), [\#9396](https://github.com/matrix-org/synapse/issues/9396))
93 - Update the version of black used to 20.8b1. ([\#9381](https://github.com/matrix-org/synapse/issues/9381))
94 - Allow OIDC config to override discovered values. ([\#9384](https://github.com/matrix-org/synapse/issues/9384))
95 - Remove some dead code from the acceptance of room invites path. ([\#9394](https://github.com/matrix-org/synapse/issues/9394))
96 - Clean up an unused method in the presence handler code. ([\#9408](https://github.com/matrix-org/synapse/issues/9408))
97
98
099 Synapse 1.27.0 (2021-02-16)
1100 ===========================
2101
3102 Note that this release includes a change in Synapse to use Redis as a cache, as well as a pub/sub mechanism, if Redis support is enabled for workers. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically.
4103
5 This release also changes the callback URI for OpenID Connect (OIDC) identity providers. If your server is configured to use single sign-on via an OIDC/OAuth2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
104 This release also changes the callback URI for OpenID Connect (OIDC) and SAML2 identity providers. If your server is configured to use single sign-on via an OIDC/OAuth2 or SAML2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
6105
7106 This release also changes escaping of variables in the HTML templates for SSO or email notifications. If you have customised these templates, please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
8107
0 # Contributing code to Synapse
0 Welcome to Synapse
1
2 This document aims to get you started with contributing to this repo!
3
4 - [1. Who can contribute to Synapse?](#1-who-can-contribute-to-synapse)
5 - [2. What do I need?](#2-what-do-i-need)
6 - [3. Get the source.](#3-get-the-source)
7 - [4. Install the dependencies](#4-install-the-dependencies)
8 * [Under Unix (macOS, Linux, BSD, ...)](#under-unix-macos-linux-bsd-)
9 * [Under Windows](#under-windows)
10 - [5. Get in touch.](#5-get-in-touch)
11 - [6. Pick an issue.](#6-pick-an-issue)
12 - [7. Turn coffee and documentation into code and documentation!](#7-turn-coffee-and-documentation-into-code-and-documentation)
13 - [8. Test, test, test!](#8-test-test-test)
14 * [Run the linters.](#run-the-linters)
15 * [Run the unit tests.](#run-the-unit-tests)
16 * [Run the integration tests.](#run-the-integration-tests)
17 - [9. Submit your patch.](#9-submit-your-patch)
18 * [Changelog](#changelog)
19 + [How do I know what to call the changelog file before I create the PR?](#how-do-i-know-what-to-call-the-changelog-file-before-i-create-the-pr)
20 + [Debian changelog](#debian-changelog)
21 * [Sign off](#sign-off)
22 - [10. Turn feedback into better code.](#10-turn-feedback-into-better-code)
23 - [11. Find a new issue.](#11-find-a-new-issue)
24 - [Notes for maintainers on merging PRs etc](#notes-for-maintainers-on-merging-prs-etc)
25 - [Conclusion](#conclusion)
26
27 # 1. Who can contribute to Synapse?
128
229 Everyone is welcome to contribute code to [matrix.org
330 projects](https://github.com/matrix-org), provided that they are willing to
835 license - in our case, this is almost always Apache Software License v2 (see
936 [LICENSE](LICENSE)).
1037
11 ## How to contribute
38 # 2. What do I need?
39
40 The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://wiki.python.org/moin/BeginnersGuide/Download).
41
42 The source code of Synapse is hosted on GitHub. You will also need [a recent version of git](https://github.com/git-guides/install-git).
43
44 For some tests, you will need [a recent version of Docker](https://docs.docker.com/get-docker/).
45
46
47 # 3. Get the source.
1248
1349 The preferred and easiest way to contribute changes is to fork the relevant
14 project on github, and then [create a pull request](
50 project on GitHub, and then [create a pull request](
1551 https://help.github.com/articles/using-pull-requests/) to ask us to pull your
1652 changes into our repo.
1753
18 Some other points to follow:
19
20 * Please base your changes on the `develop` branch.
21
22 * Please follow the [code style requirements](#code-style).
23
24 * Please include a [changelog entry](#changelog) with each PR.
25
26 * Please [sign off](#sign-off) your contribution.
27
28 * Please keep an eye on the pull request for feedback from the [continuous
29 integration system](#continuous-integration-and-testing) and try to fix any
30 errors that come up.
31
32 * If you need to [update your PR](#updating-your-pull-request), just add new
33 commits to your branch rather than rebasing.
34
35 ## Code style
54 Please base your changes on the `develop` branch.
55
56 ```sh
57 git clone git@github.com:YOUR_GITHUB_USER_NAME/synapse.git
58 git checkout develop
59 ```
60
61 If you need help getting started with git, that is beyond the scope of this document, but you
62 can find many good git tutorials on the web.
63
64 # 4. Install the dependencies
65
66 ## Under Unix (macOS, Linux, BSD, ...)
67
68 Once you have installed Python 3 and obtained the source, open a terminal and
69 set up a *virtualenv*, as follows:
70
71 ```sh
72 cd path/where/you/have/cloned/the/repository
73 python3 -m venv ./env
74 source ./env/bin/activate
75 pip install -e ".[all,lint,mypy,test]"
76 pip install tox
77 ```
78
79 This will install the developer dependencies for the project.
80
81 ## Under Windows
82
83 TBD
84
85
86 # 5. Get in touch.
87
88 Join our developer community on Matrix: #synapse-dev:matrix.org!
89
90
91 # 6. Pick an issue.
92
93 Fix your favorite problem or perhaps find a [Good First Issue](https://github.com/matrix-org/synapse/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+First+Issue%22)
94 to work on.
95
96
97 # 7. Turn coffee and documentation into code and documentation!
3698
3799 Synapse's code style is documented [here](docs/code_style.md). Please follow
38100 it, including the conventions for the [sample configuration
39101 file](docs/code_style.md#configuration-file-format).
40
41 Many of the conventions are enforced by scripts which are run as part of the
42 [continuous integration system](#continuous-integration-and-testing). To help
43 check if you have followed the code style, you can run `scripts-dev/lint.sh`
44 locally. You'll need python 3.6 or later, and to install a number of tools:
45
46 ```
47 # Install the dependencies
48 pip install -e ".[lint,mypy]"
49
50 # Run the linter script
51 ./scripts-dev/lint.sh
52 ```
53
54 **Note that the script does not just test/check, but also reformats code, so you
55 may wish to ensure any new code is committed first**.
56
57 By default, this script checks all files and can take some time; if you alter
58 only certain files, you might wish to specify paths as arguments to reduce the
59 run-time:
60
61 ```
62 ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
63 ```
64
65 You can also provide the `-d` option, which will lint the files that have been
66 changed since the last git commit. This will often be significantly faster than
67 linting the whole codebase.
68
69 Before pushing new changes, ensure they don't produce linting errors. Commit any
70 files that were corrected.
71
72 Please ensure your changes match the cosmetic style of the existing project,
73 and **never** mix cosmetic and functional changes in the same commit, as it
74 makes it horribly hard to review otherwise.
75
76 ## Changelog
77
78 All changes, even minor ones, need a corresponding changelog / newsfragment
79 entry. These are managed by [Towncrier](https://github.com/hawkowl/towncrier).
80
81 To create a changelog entry, make a new file in the `changelog.d` directory named
82 in the format of `PRnumber.type`. The type can be one of the following:
83
84 * `feature`
85 * `bugfix`
86 * `docker` (for updates to the Docker image)
87 * `doc` (for updates to the documentation)
88 * `removal` (also used for deprecations)
89 * `misc` (for internal-only changes)
90
91 This file will become part of our [changelog](
92 https://github.com/matrix-org/synapse/blob/master/CHANGES.md) at the next
93 release, so the content of the file should be a short description of your
94 change in the same style as the rest of the changelog. The file can contain Markdown
95 formatting, and should end with a full stop (.) or an exclamation mark (!) for
96 consistency.
97
98 Adding credits to the changelog is encouraged; we value your
99 contributions and would like to have you shouted out in the release notes!
100
101 For example, a fix in PR #1234 would have its changelog entry in
102 `changelog.d/1234.bugfix`, and contain content like:
103
104 > The security levels of Florbs are now validated when received
105 > via the `/federation/florb` endpoint. Contributed by Jane Matrix.
106
107 If there are multiple pull requests involved in a single bugfix/feature/etc,
108 then the content for each `changelog.d` file should be the same. Towncrier will
109 merge the matching files together into a single changelog entry when we come to
110 release.
111
112 ### How do I know what to call the changelog file before I create the PR?
113
114 Obviously, you don't know if you should call your newsfile
115 `1234.bugfix` or `5678.bugfix` until you create the PR, which leads to a
116 chicken-and-egg problem.
117
118 There are two options for solving this:
119
120 1. Open the PR without a changelog file, see what number you got, and *then*
121 add the changelog file to your branch (see [Updating your pull
122 request](#updating-your-pull-request)), or:
123
124 1. Look at the [list of all
125 issues/PRs](https://github.com/matrix-org/synapse/issues?q=), add one to the
126 highest number you see, and quickly open the PR before somebody else claims
127 your number.
128
129 [This
130 script](https://github.com/richvdh/scripts/blob/master/next_github_number.sh)
131 might be helpful if you find yourself doing this a lot.
132
133 Sorry, we know it's a bit fiddly, but it's *really* helpful for us when we come
134 to put together a release!
135
136 ### Debian changelog
137
138 Changes which affect the debian packaging files (in `debian`) are an
139 exception to the rule that all changes require a `changelog.d` file.
140
141 In this case, you will need to add an entry to the debian changelog for the
142 next release. For this, run the following command:
143
144 ```
145 dch
146 ```
147
148 This will make up a new version number (if there isn't already an unreleased
149 version in flight), and open an editor where you can add a new changelog entry.
150 (Our release process will ensure that the version number and maintainer name are
151 corrected for the release.)
152
153 If your change affects both the debian packaging *and* files outside the debian
154 directory, you will need both a regular newsfragment *and* an entry in the
155 debian changelog. (Though typically such changes should be submitted as two
156 separate pull requests.)
157
158 ## Documentation
159102
160103 There is a growing amount of documentation located in the [docs](docs)
161104 directory. This documentation is intended primarily for sysadmins running their
165108 regarding Synapse's Admin API, which is used mostly by sysadmins and external
166109 service developers.
167110
168 New files added to both folders should be written in [Github-Flavoured
169 Markdown](https://guides.github.com/features/mastering-markdown/), and attempts
170 should be made to migrate existing documents to markdown where possible.
171
172 Some documentation also exists in [Synapse's Github
111 If you add new files added to either of these folders, please use [GitHub-Flavoured
112 Markdown](https://guides.github.com/features/mastering-markdown/).
113
114 Some documentation also exists in [Synapse's GitHub
173115 Wiki](https://github.com/matrix-org/synapse/wiki), although this is primarily
174116 contributed to by community authors.
117
118
119 # 8. Test, test, test!
120 <a name="test-test-test"></a>
121
122 While you're developing and before submitting a patch, you'll
123 want to test your code.
124
125 ## Run the linters.
126
127 The linters look at your code and do two things:
128
129 - ensure that your code follows the coding style adopted by the project;
130 - catch a number of errors in your code.
131
132 They're pretty fast, don't hesitate!
133
134 ```sh
135 source ./env/bin/activate
136 ./scripts-dev/lint.sh
137 ```
138
139 Note that this script *will modify your files* to fix styling errors.
140 Make sure that you have saved all your files.
141
142 If you wish to restrict the linters to only the files changed since the last commit
143 (much faster!), you can instead run:
144
145 ```sh
146 source ./env/bin/activate
147 ./scripts-dev/lint.sh -d
148 ```
149
150 Or if you know exactly which files you wish to lint, you can instead run:
151
152 ```sh
153 source ./env/bin/activate
154 ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
155 ```
156
157 ## Run the unit tests.
158
159 The unit tests run parts of Synapse, including your changes, to see if anything
160 was broken. They are slower than the linters but will typically catch more errors.
161
162 ```sh
163 source ./env/bin/activate
164 trial tests
165 ```
166
167 If you wish to only run *some* unit tests, you may specify
168 another module instead of `tests` - or a test class or a method:
169
170 ```sh
171 source ./env/bin/activate
172 trial tests.rest.admin.test_room tests.handlers.test_admin.ExfiltrateData.test_invite
173 ```
174
175 If your tests fail, you may wish to look at the logs:
176
177 ```sh
178 less _trial_temp/test.log
179 ```
180
181 ## Run the integration tests.
182
183 The integration tests are a more comprehensive suite of tests. They
184 run a full version of Synapse, including your changes, to check if
185 anything was broken. They are slower than the unit tests but will
186 typically catch more errors.
187
188 The following command will let you run the integration test with the most common
189 configuration:
190
191 ```sh
192 docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:py37
193 ```
194
195 This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md).
196
197
198 # 9. Submit your patch.
199
200 Once you're happy with your patch, it's time to prepare a Pull Request.
201
202 To prepare a Pull Request, please:
203
204 1. verify that [all the tests pass](#test-test-test), including the coding style;
205 2. [sign off](#sign-off) your contribution;
206 3. `git push` your commit to your fork of Synapse;
207 4. on GitHub, [create the Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request);
208 5. add a [changelog entry](#changelog) and push it to your Pull Request;
209 6. for most contributors, that's all - however, if you are a member of the organization `matrix-org`, on GitHub, please request a review from `matrix.org / Synapse Core`.
210
211
212 ## Changelog
213
214 All changes, even minor ones, need a corresponding changelog / newsfragment
215 entry. These are managed by [Towncrier](https://github.com/hawkowl/towncrier).
216
217 To create a changelog entry, make a new file in the `changelog.d` directory named
218 in the format of `PRnumber.type`. The type can be one of the following:
219
220 * `feature`
221 * `bugfix`
222 * `docker` (for updates to the Docker image)
223 * `doc` (for updates to the documentation)
224 * `removal` (also used for deprecations)
225 * `misc` (for internal-only changes)
226
227 This file will become part of our [changelog](
228 https://github.com/matrix-org/synapse/blob/master/CHANGES.md) at the next
229 release, so the content of the file should be a short description of your
230 change in the same style as the rest of the changelog. The file can contain Markdown
231 formatting, and should end with a full stop (.) or an exclamation mark (!) for
232 consistency.
233
234 Adding credits to the changelog is encouraged; we value your
235 contributions and would like to have you shouted out in the release notes!
236
237 For example, a fix in PR #1234 would have its changelog entry in
238 `changelog.d/1234.bugfix`, and contain content like:
239
240 > The security levels of Florbs are now validated when received
241 > via the `/federation/florb` endpoint. Contributed by Jane Matrix.
242
243 If there are multiple pull requests involved in a single bugfix/feature/etc,
244 then the content for each `changelog.d` file should be the same. Towncrier will
245 merge the matching files together into a single changelog entry when we come to
246 release.
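As a concrete sketch of the naming scheme above (using the hypothetical PR #1234 and the example wording from this section), creating a newsfragment is a one-liner:

```sh
# Sketch only: "1234" is the hypothetical PR number from the example above.
mkdir -p changelog.d
echo "The security levels of Florbs are now validated when received via the /federation/florb endpoint. Contributed by Jane Matrix." > changelog.d/1234.bugfix
```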
247
248 ### How do I know what to call the changelog file before I create the PR?
249
250 Obviously, you don't know if you should call your newsfile
251 `1234.bugfix` or `5678.bugfix` until you create the PR, which leads to a
252 chicken-and-egg problem.
253
254 There are two options for solving this:
255
256 1. Open the PR without a changelog file, see what number you got, and *then*
257 add the changelog file to your branch (see [Updating your pull
258 request](#updating-your-pull-request)), or:
259
260 1. Look at the [list of all
261 issues/PRs](https://github.com/matrix-org/synapse/issues?q=), add one to the
262 highest number you see, and quickly open the PR before somebody else claims
263 your number.
264
265 [This
266 script](https://github.com/richvdh/scripts/blob/master/next_github_number.sh)
267 might be helpful if you find yourself doing this a lot.
268
269 Sorry, we know it's a bit fiddly, but it's *really* helpful for us when we come
270 to put together a release!
271
272 ### Debian changelog
273
274 Changes which affect the debian packaging files (in `debian`) are an
275 exception to the rule that all changes require a `changelog.d` file.
276
277 In this case, you will need to add an entry to the debian changelog for the
278 next release. For this, run the following command:
279
280 ```
281 dch
282 ```
283
284 This will make up a new version number (if there isn't already an unreleased
285 version in flight), and open an editor where you can add a new changelog entry.
286 (Our release process will ensure that the version number and maintainer name are
287 corrected for the release.)
288
289 If your change affects both the debian packaging *and* files outside the debian
290 directory, you will need both a regular newsfragment *and* an entry in the
291 debian changelog. (Though typically such changes should be submitted as two
292 separate pull requests.)
175293
176294 ## Sign off
177295
239357 flag to `git commit`, which uses the name and email set in your
240358 `user.name` and `user.email` git configs.
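As a concrete sketch (with a placeholder identity, in a throwaway repository rather than your real checkout), the `-s` flag adds the trailer automatically:

```sh
# Throwaway demo repo; in practice run `git commit -s` in your Synapse checkout.
cd "$(mktemp -d)" && git init -q .
git config user.name "Jane Matrix"        # placeholder name
git config user.email "jane@example.com"  # placeholder email
git commit -q -s --allow-empty -m "Fix a bug in the florb handler."
git log -1 --format=%B                    # ends with the Signed-off-by trailer
```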
241359
242 ## Continuous integration and testing
243
244 [Buildkite](https://buildkite.com/matrix-dot-org/synapse) will automatically
245 run a series of checks and tests against any PR which is opened against the
246 project; if your change breaks the build, this will be shown in GitHub, with
247 links to the build results. If your build fails, please try to fix the errors
248 and update your branch.
249
250 To run unit tests in a local development environment, you can use:
251
252 - ``tox -e py35`` (requires tox to be installed by ``pip install tox``)
253 for SQLite-backed Synapse on Python 3.5.
254 - ``tox -e py36`` for SQLite-backed Synapse on Python 3.6.
255 - ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6
256 (requires a running local PostgreSQL with access to create databases).
257 - ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5
258 (requires Docker). Entirely self-contained, recommended if you don't want to
259 set up PostgreSQL yourself.
260
261 Docker images are available for running the integration tests (SyTest) locally,
262 see the [documentation in the SyTest repo](
263 https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more
264 information.
265
266 ## Updating your pull request
267
268 If you decide to make changes to your pull request - perhaps to address issues
269 raised in a review, or to fix problems highlighted by [continuous
270 integration](#continuous-integration-and-testing) - just add new commits to your
271 branch, and push to GitHub. The pull request will automatically be updated.
272
273 Please **avoid** rebasing your branch, especially once the PR has been
274 reviewed: doing so makes it very difficult for a reviewer to see what has
275 changed since a previous review.
276
277 ## Notes for maintainers on merging PRs etc
360
361 # 10. Turn feedback into better code.
362
363 Once the Pull Request is opened, you will see a few things:
364
365 1. our automated CI (Continuous Integration) pipeline will run (again) the linters, the unit tests, the integration tests and more;
366 2. one or more of the developers will take a look at your Pull Request and offer feedback.
367
368 From this point, you should:
369
370 1. Look at the results of the CI pipeline.
371    - If there are any errors, fix them.
372 2. If a developer has requested changes, make these changes and let us know if it is ready for a developer to review again.
373 3. Create a new commit with the changes.
374 - Please do NOT overwrite the history. New commits make the reviewer's life easier.
375    - Push these commits to your Pull Request.
376 4. Back to 1.
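The "new commit, no rebase" step above can be sketched as follows (in a throwaway repository with a hypothetical file name; in practice these commands run on your PR branch):

```sh
# Throwaway demo repo standing in for your PR branch.
cd "$(mktemp -d)" && git init -q .
git config user.name "Jane Matrix" && git config user.email "jane@example.com"
git commit -q --allow-empty -m "Initial patch"
echo "fixed" > reviewed_file.py              # hypothetical file changed after review
git add reviewed_file.py
git commit -q -m "Address review feedback"   # a NEW commit; no amend, no rebase
git log --oneline                            # both commits are still there
# Then `git push` to your fork: the open Pull Request updates automatically.
```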
377
378 Once both the CI and the developers are happy, the patch will be merged into Synapse and released shortly!
379
380 # 11. Find a new issue.
381
382 By now, you know the drill!
383
384 # Notes for maintainers on merging PRs etc
278385
279386 There are some notes for those with commit access to the project on how we
280387 manage git [here](docs/dev/git.md).
281388
282 ## Conclusion
389 # Conclusion
283390
284391 That's it! Matrix is a very open and collaborative project as you might expect
285392 given our obsession with open communication. If we're going to successfully
150150
151151 ##### CentOS/Fedora
152152
153 Installing prerequisites on CentOS 8 or Fedora>26:
153 Installing prerequisites on CentOS or Fedora Linux:
154154
155155 ```sh
156156 sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
157 libwebp-devel tk-devel redhat-rpm-config \
158 python3-virtualenv libffi-devel openssl-devel
157 libwebp-devel libxml2-devel libxslt-devel libpq-devel \
158 python3-virtualenv libffi-devel openssl-devel python3-devel
159159 sudo dnf groupinstall "Development Tools"
160160 ```
161
162 Installing prerequisites on CentOS 7 or Fedora<=25:
163
164 ```sh
165 sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
166 lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
167 python3-virtualenv libffi-devel openssl-devel
168 sudo yum groupinstall "Development Tools"
169 ```
170
171 Note that Synapse does not support versions of SQLite before 3.11, and CentOS 7
172 uses SQLite 3.7. You may be able to work around this by installing a more
173 recent SQLite version, but it is recommended that you instead use a Postgres
174 database: see [docs/postgres.md](docs/postgres.md).
175161
176162 ##### macOS
177163
8787 Upgrading to v1.27.0
8888 ====================
8989
90 Changes to callback URI for OAuth2 / OpenID Connect
91 ---------------------------------------------------
92
93 This version changes the URI used for callbacks from OAuth2 identity providers. If
94 your server is configured for single sign-on via an OpenID Connect or OAuth2 identity
95 provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback``
96 to the list of permitted "redirect URIs" at the identity provider.
97
98 See `docs/openid.md <docs/openid.md>`_ for more information on setting up OpenID
99 Connect.
100
101 (Note: a similar change is being made for SAML2; in this case the old URI
102 ``[synapse public baseurl]/_matrix/saml2`` is being deprecated, but will continue to
103 work, so no immediate changes are required for existing installations.)
90 Changes to callback URI for OAuth2 / OpenID Connect and SAML2
91 -------------------------------------------------------------
92
93 This version changes the URI used for callbacks from OAuth2 and SAML2 identity providers:
94
95 * If your server is configured for single sign-on via an OpenID Connect or OAuth2 identity
96 provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback``
97 to the list of permitted "redirect URIs" at the identity provider.
98
99 See `docs/openid.md <docs/openid.md>`_ for more information on setting up OpenID
100 Connect.
101
102 * If your server is configured for single sign-on via a SAML2 identity provider, you will
103 need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
104 "ACS location" (also known as "allowed callback URLs") at the identity provider.
104105
105106 Changes to HTML templates
106107 -------------------------
9191 return self.config["user"].split(":")[1]
9292
9393 def do_config(self, line):
94 """ Show the config for this client: "config"
94 """Show the config for this client: "config"
9595 Edit a key value mapping: "config key value" e.g. "config token 1234"
9696 Config variables:
9797 user: The username to auth with.
359359 print(e)
360360
361361 def do_topic(self, line):
362 """"topic [set|get] <roomid> [<newtopic>]"
362 """ "topic [set|get] <roomid> [<newtopic>]"
363363 Set the topic for a room: topic set <roomid> <newtopic>
364364 Get the topic for a room: topic get <roomid>
365365 """
689689 self._do_presence_state(2, line)
690690
691691 def _parse(self, line, keys, force_keys=False):
692 """ Parses the given line.
692 """Parses the given line.
693693
694694 Args:
695695 line : The line to parse
720720 query_params={"access_token": None},
721721 alt_text=None,
722722 ):
723 """ Runs an HTTP request and pretty prints the output.
723 """Runs an HTTP request and pretty prints the output.
724724
725725 Args:
726726 method: HTTP method
2222
2323
2424 class HttpClient:
25 """ Interface for talking json over http
26 """
25 """Interface for talking json over http"""
2726
2827 def put_json(self, url, data):
29 """ Sends the specifed json data using PUT
28 """Sends the specifed json data using PUT
3029
3130 Args:
3231 url (str): The URL to PUT data to.
4039 pass
4140
4241 def get_json(self, url, args=None):
43 """ Gets some json from the given host homeserver and path
42 """Gets some json from the given host homeserver and path
4443
4544 Args:
4645 url (str): The URL to GET data from.
5756
5857
5958 class TwistedHttpClient(HttpClient):
60 """ Wrapper around the twisted HTTP client api.
59 """Wrapper around the twisted HTTP client api.
6160
6261 Attributes:
6362 agent (twisted.web.client.Agent): The twisted Agent used to send the
8685 defer.returnValue(json.loads(body))
8786
8887 def _create_put_request(self, url, json_data, headers_dict={}):
89 """ Wrapper of _create_request to issue a PUT request
90 """
88 """Wrapper of _create_request to issue a PUT request"""
9189
9290 if "Content-Type" not in headers_dict:
9391 raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
9795 )
9896
9997 def _create_get_request(self, url, headers_dict={}):
100 """ Wrapper of _create_request to issue a GET request
101 """
98 """Wrapper of _create_request to issue a GET request"""
10299 return self._create_request("GET", url, headers_dict=headers_dict)
103100
104101 @defer.inlineCallbacks
126123
127124 @defer.inlineCallbacks
128125 def _create_request(self, method, url, producer=None, headers_dict={}):
129 """ Creates and sends a request to the given url
130 """
126 """Creates and sends a request to the given url"""
131127 headers_dict["User-Agent"] = ["Synapse Cmd Client"]
132128
133129 retries_left = 5
184180
185181
186182 class _JsonProducer:
187 """ Used by the twisted http client to create the HTTP body from json
188 """
183 """Used by the twisted http client to create the HTTP body from json"""
189184
190185 def __init__(self, jsn):
191186 self.data = jsn
6262 self.redraw()
6363
6464 def redraw(self):
65 """ method for redisplaying lines
66 based on internal list of lines """
65 """method for redisplaying lines based on internal list of lines"""
6766
6867 self.stdscr.clear()
6968 self.paintStatus(self.statusText)
5555
5656
5757 class InputOutput:
58 """ This is responsible for basic I/O so that a user can interact with
58 """This is responsible for basic I/O so that a user can interact with
5959 the example app.
6060 """
6161
6767 self.server = server
6868
6969 def on_line(self, line):
70 """ This is where we process commands.
71 """
70 """This is where we process commands."""
7271
7372 try:
7473 m = re.match(r"^join (\S+)$", line)
132131
133132
134133 class Room:
135 """ Used to store (in memory) the current membership state of a room, and
134 """Used to store (in memory) the current membership state of a room, and
136135 which home servers we should send PDUs associated with the room to.
137136 """
138137
147146 self.have_got_metadata = False
148147
149148 def add_participant(self, participant):
150 """ Someone has joined the room
151 """
149 """Someone has joined the room"""
152150 self.participants.add(participant)
153151 self.invited.discard(participant)
154152
159157 self.oldest_server = server
160158
161159 def add_invited(self, invitee):
162 """ Someone has been invited to the room
163 """
160 """Someone has been invited to the room"""
164161 self.invited.add(invitee)
165162 self.servers.add(origin_from_ucid(invitee))
166163
167164
168165 class HomeServer(ReplicationHandler):
169 """ A very basic home server implentation that allows people to join a
166 """A very basic home server implentation that allows people to join a
170167 room and then invite other people.
171168 """
172169
180177 self.output = output
181178
182179 def on_receive_pdu(self, pdu):
183 """ We just received a PDU
184 """
180 """We just received a PDU"""
185181 pdu_type = pdu.pdu_type
186182
187183 if pdu_type == "sy.room.message":
198194 )
199195
200196 def _on_message(self, pdu):
201 """ We received a message
202 """
197 """We received a message"""
203198 self.output.print_line(
204199 "#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"])
205200 )
206201
207202 def _on_join(self, context, joinee):
208 """ Someone has joined a room, either a remote user or a local user
209 """
203 """Someone has joined a room, either a remote user or a local user"""
210204 room = self._get_or_create_room(context)
211205 room.add_participant(joinee)
212206
213207 self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED"))
214208
215209 def _on_invite(self, origin, context, invitee):
216 """ Someone has been invited
217 """
210 """Someone has been invited"""
218211 room = self._get_or_create_room(context)
219212 room.add_invited(invitee)
220213
227220
228221 @defer.inlineCallbacks
229222 def send_message(self, room_name, sender, body):
230 """ Send a message to a room!
231 """
223 """Send a message to a room!"""
232224 destinations = yield self.get_servers_for_context(room_name)
233225
234226 try:
246238
247239 @defer.inlineCallbacks
248240 def join_room(self, room_name, sender, joinee):
249 """ Join a room!
250 """
241 """Join a room!"""
251242 self._on_join(room_name, joinee)
252243
253244 destinations = yield self.get_servers_for_context(room_name)
268259
269260 @defer.inlineCallbacks
270261 def invite_to_room(self, room_name, sender, invitee):
271 """ Invite someone to a room!
272 """
262 """Invite someone to a room!"""
273263 self._on_invite(self.server_name, room_name, invitee)
274264
275265 destinations = yield self.get_servers_for_context(room_name)
192192 time.sleep(7)
193193 print("SSRC spammer started")
194194 while self.running:
195 ssrcMsg = (
196 "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>"
197 % {
198 "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
199 "nick": self.userId,
200 "assrc": self.ssrcs["audio"],
201 "vssrc": self.ssrcs["video"],
202 }
203 )
195 ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % {
196 "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
197 "nick": self.userId,
198 "assrc": self.ssrcs["audio"],
199 "vssrc": self.ssrcs["video"],
200 }
204201 res = self.sendIq(ssrcMsg)
205202 print("reply from ssrc announce: ", res)
206203 time.sleep(10)
0 matrix-synapse-py3 (1.28.0) stable; urgency=medium
1
2 * New synapse release 1.28.0.
3
4 -- Synapse Packaging team <packages@matrix.org> Thu, 25 Feb 2021 10:21:57 +0000
5
06 matrix-synapse-py3 (1.27.0) stable; urgency=medium
17
28 [ Dan Callahan ]
99 * [Undoing room shutdowns](#undoing-room-shutdowns)
1010 - [Make Room Admin API](#make-room-admin-api)
1111 - [Forward Extremities Admin API](#forward-extremities-admin-api)
12 - [Event Context API](#event-context-api)
1213
1314 # List Room API
1415
593594 "deleted": 1
594595 }
595596 ```
597
598 # Event Context API
599
600 This API lets a client find the context of an event. This is designed primarily to investigate abuse reports.
601
602 ```
603 GET /_synapse/admin/v1/rooms/<room_id>/context/<event_id>
604 ```
605
606 This API mimics [GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-context-eventid). Please refer to the link for all details on parameters and response.
607
608 Example response:
609
610 ```json
611 {
612 "end": "t29-57_2_0_2",
613 "events_after": [
614 {
615 "content": {
616 "body": "This is an example text message",
617 "msgtype": "m.text",
618 "format": "org.matrix.custom.html",
619 "formatted_body": "<b>This is an example text message</b>"
620 },
621 "type": "m.room.message",
622 "event_id": "$143273582443PhrSn:example.org",
623 "room_id": "!636q39766251:example.com",
624 "sender": "@example:example.org",
625 "origin_server_ts": 1432735824653,
626 "unsigned": {
627 "age": 1234
628 }
629 }
630 ],
631 "event": {
632 "content": {
633 "body": "filename.jpg",
634 "info": {
635 "h": 398,
636 "w": 394,
637 "mimetype": "image/jpeg",
638 "size": 31037
639 },
640 "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
641 "msgtype": "m.image"
642 },
643 "type": "m.room.message",
644 "event_id": "$f3h4d129462ha:example.com",
645 "room_id": "!636q39766251:example.com",
646 "sender": "@example:example.org",
647 "origin_server_ts": 1432735824653,
648 "unsigned": {
649 "age": 1234
650 }
651 },
652 "events_before": [
653 {
654 "content": {
655 "body": "something-important.doc",
656 "filename": "something-important.doc",
657 "info": {
658 "mimetype": "application/msword",
659 "size": 46144
660 },
661 "msgtype": "m.file",
662 "url": "mxc://example.org/FHyPlCeYUSFFxlgbQYZmoEoe"
663 },
664 "type": "m.room.message",
665 "event_id": "$143273582443PhrSn:example.org",
666 "room_id": "!636q39766251:example.com",
667 "sender": "@example:example.org",
668 "origin_server_ts": 1432735824653,
669 "unsigned": {
670 "age": 1234
671 }
672 }
673 ],
674 "start": "t27-54_2_0_2",
675 "state": [
676 {
677 "content": {
678 "creator": "@example:example.org",
679 "room_version": "1",
680 "m.federate": true,
681 "predecessor": {
682 "event_id": "$something:example.org",
683 "room_id": "!oldroom:example.org"
684 }
685 },
686 "type": "m.room.create",
687 "event_id": "$143273582443PhrSn:example.org",
688 "room_id": "!636q39766251:example.com",
689 "sender": "@example:example.org",
690 "origin_server_ts": 1432735824653,
691 "unsigned": {
692 "age": 1234
693 },
694 "state_key": ""
695 },
696 {
697 "content": {
698 "membership": "join",
699 "avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF",
700 "displayname": "Alice Margatroid"
701 },
702 "type": "m.room.member",
703 "event_id": "$143273582443PhrSn:example.org",
704 "room_id": "!636q39766251:example.com",
705 "sender": "@example:example.org",
706 "origin_server_ts": 1432735824653,
707 "unsigned": {
708 "age": 1234
709 },
710 "state_key": "@alice:example.org"
711 }
712 ]
713 }
714 ```
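The new Event Context admin endpoint above can be driven by a short client script. A minimal sketch using only the standard library; the base URL and admin token are placeholders you would replace with your own:

```python
from urllib.parse import quote
from urllib.request import Request

# Hypothetical homeserver base URL and admin access token.
BASE_URL = "https://matrix.example.com"
ADMIN_TOKEN = "syt_admin_token"


def event_context_request(room_id: str, event_id: str, limit: int = 10) -> Request:
    """Build a GET request for the Event Context admin API added in this release.

    Room and event IDs contain characters like '!' and '$', so they must be
    percent-encoded before being placed in the path.
    """
    path = "/_synapse/admin/v1/rooms/%s/context/%s" % (
        quote(room_id, safe=""),
        quote(event_id, safe=""),
    )
    return Request(
        BASE_URL + path + "?limit=%d" % limit,
        headers={"Authorization": "Bearer " + ADMIN_TOKEN},
    )


req = event_context_request("!636q39766251:example.com", "$f3h4d129462ha:example.com")
print(req.full_url)
```

Passing the request to `urllib.request.urlopen` (against a real homeserver) would return the JSON body shown in the example response above.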
2828 }
2929 ],
3030 "avatar_url": "<avatar_url>",
31 "admin": false,
32 "deactivated": false,
31 "admin": 0,
32 "deactivated": 0,
33 "shadow_banned": 0,
3334 "password_hash": "$2b$12$p9B4GkqYdRTPGD",
3435 "creation_ts": 1560432506,
3536 "appservice_id": null,
149150 "admin": 0,
150151 "user_type": null,
151152 "deactivated": 0,
153 "shadow_banned": 0,
152154 "displayname": "<User One>",
153155 "avatar_url": null
154156 }, {
157159 "admin": 1,
158160 "user_type": null,
159161 "deactivated": 0,
162 "shadow_banned": 0,
160163 "displayname": "<User Two>",
161164 "avatar_url": "<avatar_url>"
162165 }
261264 - Reject all pending invites
262265 - Remove all account validity information related to the user
263266
264 The following additional actions are performed during deactivation if``erase``
267 The following additional actions are performed during deactivation if ``erase``
265268 is set to ``true``:
266269
267270 - Remove the user's display name
77
88 The necessary tools are detailed below.
99
10 First install them with:
11
12 pip install -e ".[lint,mypy]"
13
1014 - **black**
1115
1216 The Synapse codebase uses [black](https://pypi.org/project/black/)
1317 as an opinionated code formatter, ensuring all committed code is
1418 properly formatted.
15
16 First install `black` with:
17
18 pip install --upgrade black
1919
2020 Have `black` auto-format your code (it shouldn't change any
2121 functionality) with:
2727 `flake8` is a code checking tool. We require code to pass `flake8`
2828 before being merged into the codebase.
2929
30 Install `flake8` with:
31
32 pip install --upgrade flake8 flake8-comprehensions
33
3430 Check all application and test code with:
3531
3632 flake8 synapse tests
3935
4036 `isort` ensures imports are nicely formatted, and can suggest and
4137 auto-fix issues such as double-importing.
42
43 Install `isort` with:
44
45 pip install --upgrade isort
4638
4739 Auto-fix imports with:
4840
364364 does not return a `sub` property, an alternative `subject_claim` has to be set.
365365
366366 1. Create a new application.
367 2. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
367 2. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
368368
369369 Synapse config:
370370
387387 localpart_template: "{{ user.login }}"
388388 display_name_template: "{{ user.full_name }}"
389389 ```
390
391 ### XWiki
392
393 Install [OpenID Connect Provider](https://extensions.xwiki.org/xwiki/bin/view/Extension/OpenID%20Connect/OpenID%20Connect%20Provider/) extension in your [XWiki](https://www.xwiki.org) instance.
394
395 Synapse config:
396
397 ```yaml
398 oidc_providers:
399 - idp_id: xwiki
400 idp_name: "XWiki"
401 issuer: "https://myxwikihost/xwiki/oidc/"
402 client_id: "your-client-id" # TO BE FILLED
403 # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed
404 client_secret: "dontcare"
405 scopes: ["openid", "profile"]
406 user_profile_method: "userinfo_endpoint"
407 user_mapping_provider:
408 config:
409 localpart_template: "{{ user.preferred_username }}"
410 display_name_template: "{{ user.name }}"
411 ```
3939
4040 ```
4141 server {
42 listen 443 ssl;
43 listen [::]:443 ssl;
42 listen 443 ssl http2;
43 listen [::]:443 ssl http2;
4444
4545 # For the federation port
46 listen 8448 ssl default_server;
47 listen [::]:8448 ssl default_server;
46 listen 8448 ssl http2 default_server;
47 listen [::]:8448 ssl http2 default_server;
4848
4949 server_name matrix.example.com;
5050
164164 # - '100.64.0.0/10'
165165 # - '192.0.0.0/24'
166166 # - '169.254.0.0/16'
167 # - '192.88.99.0/24'
167168 # - '198.18.0.0/15'
168169 # - '192.0.2.0/24'
169170 # - '198.51.100.0/24'
172173 # - '::1/128'
173174 # - 'fe80::/10'
174175 # - 'fc00::/7'
176 # - '2001:db8::/32'
177 # - 'ff00::/8'
178 # - 'fec0::/10'
175179
176180 # List of IP address CIDR ranges that should be allowed for federation,
177181 # identity servers, push servers, and for checking key validity for
989993 # - '100.64.0.0/10'
990994 # - '192.0.0.0/24'
991995 # - '169.254.0.0/16'
996 # - '192.88.99.0/24'
992997 # - '198.18.0.0/15'
993998 # - '192.0.2.0/24'
994999 # - '198.51.100.0/24'
9971002 # - '::1/128'
9981003 # - 'fe80::/10'
9991004 # - 'fc00::/7'
1005 # - '2001:db8::/32'
1006 # - 'ff00::/8'
1007 # - 'fec0::/10'
10001008
10011009 # List of IP address CIDR ranges that the URL preview spider is allowed
10021010 # to access even if they are specified in url_preview_ip_range_blacklist.
13171325 # By default, any room aliases included in this list will be created
13181326 # as a publicly joinable room when the first user registers for the
13191327 # homeserver. This behaviour can be customised with the settings below.
1328 # If the room already exists, make certain it is a publicly joinable
1329 # room. The join rule of the room must be set to 'public'.
13201330 #
13211331 #auto_join_rooms:
13221332 # - "#example:example.com"
18591869 # user_mapping_provider:
18601870 # config:
18611871 # subject_claim: "id"
1862 # localpart_template: "{ user.login }"
1863 # display_name_template: "{ user.name }"
1864 # email_template: "{ user.email }"
1872 # localpart_template: "{{ user.login }}"
1873 # display_name_template: "{{ user.name }}"
1874 # email_template: "{{ user.email }}"
18651875
18661876 # For use with Keycloak
18671877 #
18881898 # user_mapping_provider:
18891899 # config:
18901900 # subject_claim: "id"
1891 # localpart_template: "{ user.login }"
1892 # display_name_template: "{ user.name }"
1901 # localpart_template: "{{ user.login }}"
1902 # display_name_template: "{{ user.name }}"
18931903
18941904
18951905 # Enable Central Authentication Service (CAS) for registration and login.
22172227 #require_uppercase: true
22182228
22192229 ui_auth:
2220 # The number of milliseconds to allow a user-interactive authentication
2221 # session to be active.
2230 # The amount of time to allow a user-interactive authentication session
2231 # to be active.
22222232 #
22232233 # This defaults to 0, meaning the user is queried for their credentials
2224 # before every action, but this can be overridden to alow a single
2234 # before every action, but this can be overridden to allow a single
22252235 # validation to be re-used. This weakens the protections afforded by
22262236 # the user-interactive authentication process, by allowing for multiple
22272237 # (and potentially different) operations to use the same validation session.
22292239 # Uncomment below to allow for credential validation to last for 15
22302240 # seconds.
22312241 #
2232 #session_timeout: 15000
2242 #session_timeout: "15s"
22332243
22342244
22352245 # Configuration for sending emails from Synapse.
6060
6161 async def check_registration_for_spam(self, email_threepid, username, request_info):
6262 return RegistrationBehaviour.ALLOW # allow all registrations
63
64 async def check_media_file_for_spam(self, file_wrapper, file_info):
65 return False # allow all media
6366 ```
6467
6568 ## Configuration
186186 ```
187187 * If you use systemd:
188188 ```
189 systemctl restart synapse.service
189 systemctl restart matrix-synapse.service
190190 ```
191191 ... and then reload any clients (or wait an hour for them to refresh their
192192 settings).
275275
276276 Ensure that all SSO logins go to a single process.
277277 For multiple workers not handling the SSO endpoints properly, see
278 [#7530](https://github.com/matrix-org/synapse/issues/7530).
278 [#7530](https://github.com/matrix-org/synapse/issues/7530) and
279 [#9427](https://github.com/matrix-org/synapse/issues/9427).
279280
280281 Note that a HTTP listener with `client` and `federation` resources must be
281282 configured in the `worker_listeners` option in the worker config.
372373 REST endpoints itself, but you should set `start_pushers: False` in the
373374 shared configuration file to stop the main synapse sending push notifications.
374375
375 Note this worker cannot be load-balanced: only one instance should be active.
376 To run multiple instances at once the `pusher_instances` option should list all
377 pusher instances by their worker name, e.g.:
378
379 ```yaml
380 pusher_instances:
381 - pusher_worker1
382 - pusher_worker2
383 ```
384
376385
377386 ### `synapse.app.appservice`
378387
2222 synapse/events/validator.py,
2323 synapse/events/spamcheck.py,
2424 synapse/federation,
25 synapse/groups,
2526 synapse/handlers,
2627 synapse/http/client.py,
2728 synapse/http/federation/matrix_federation_agent.py,
161161 fi
162162
163163 # Delete schema_version, applied_schema_deltas and applied_module_schemas tables
164 # Also delete any shadow tables from fts4
164165 # This needs to be done after synapse_port_db is run
165166 echo "Dropping unwanted db tables..."
166167 SQL="
167168 DROP TABLE schema_version;
168169 DROP TABLE applied_schema_deltas;
169170 DROP TABLE applied_module_schemas;
171 DROP TABLE event_search_content;
172 DROP TABLE event_search_segments;
173 DROP TABLE event_search_segdir;
174 DROP TABLE event_search_docsize;
175 DROP TABLE event_search_stat;
176 DROP TABLE user_directory_search_content;
177 DROP TABLE user_directory_search_segments;
178 DROP TABLE user_directory_search_segdir;
179 DROP TABLE user_directory_search_docsize;
180 DROP TABLE user_directory_search_stat;
170181 "
171182 sqlite3 "$SQLITE_DB" <<< "$SQL"
172183 psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL"
8686 arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
8787
8888 signature = signature.copy_modified(
89 arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
89 arg_types=arg_types,
90 arg_names=arg_names,
91 arg_kinds=arg_kinds,
9092 )
9193
9294 return signature
9696 # We pin black so that our tests don't start failing on new releases.
9797 CONDITIONAL_REQUIREMENTS["lint"] = [
9898 "isort==5.7.0",
99 "black==19.10b0",
99 "black==20.8b1",
100100 "flake8-comprehensions",
101101 "flake8",
102102 ]
8888 def __reduce__(
8989 self,
9090 ) -> Tuple[
91 Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
91 Type[SortedDict[_KT, _VT]],
92 Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
9293 ]: ...
9394 def __repr__(self) -> str: ...
9495 def _check(self) -> None: ...
9596 def islice(
96 self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
97 self,
98 start: Optional[int] = ...,
99 stop: Optional[int] = ...,
100 reverse=bool,
97101 ) -> Iterator[_KT]: ...
98102 def bisect_left(self, value: _KT) -> int: ...
99103 def bisect_right(self, value: _KT) -> int: ...
3030
3131 DEFAULT_LOAD_FACTOR: int = ...
3232 def __init__(
33 self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
33 self,
34 iterable: Optional[Iterable[_T]] = ...,
35 key: Optional[_Key[_T]] = ...,
3436 ): ...
3537 # NB: currently mypy does not honour return type, see mypy #3307
3638 @overload
7577 def __len__(self) -> int: ...
7678 def reverse(self) -> None: ...
7779 def islice(
78 self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
80 self,
81 start: Optional[int] = ...,
82 stop: Optional[int] = ...,
83 reverse=bool,
7984 ) -> Iterator[_T]: ...
8085 def _islice(
81 self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
86 self,
87 min_pos: int,
88 min_idx: int,
89 max_pos: int,
90 max_idx: int,
91 reverse: bool,
8292 ) -> Iterator[_T]: ...
8393 def irange(
8494 self,
4747 except ImportError:
4848 pass
4949
50 __version__ = "1.27.0"
50 __version__ = "1.28.0"
5151
5252 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
5353 # We import here so that we don't have to install a bunch of deps when
167167 rights: str = "access",
168168 allow_expired: bool = False,
169169 ) -> synapse.types.Requester:
170 """ Get a registered user's ID.
170 """Get a registered user's ID.
171171
172172 Args:
173173 request: An HTTP request with an access_token query parameter.
293293 return user_id, app_service
294294
295295 async def get_user_by_access_token(
296 self, token: str, rights: str = "access", allow_expired: bool = False,
296 self,
297 token: str,
298 rights: str = "access",
299 allow_expired: bool = False,
297300 ) -> TokenLookupResult:
298 """ Validate access token and get user_id from it
301 """Validate access token and get user_id from it
299302
300303 Args:
301304 token: The access token to get the user by
488491 return service
489492
490493 async def is_server_admin(self, user: UserID) -> bool:
491 """ Check if the given user is a local server admin.
494 """Check if the given user is a local server admin.
492495
493496 Args:
494497 user: user to check
499502 return await self.store.is_server_admin(user)
500503
501504 def compute_auth_events(
502 self, event, current_state_ids: StateMap[str], for_verification: bool = False,
505 self,
506 event,
507 current_state_ids: StateMap[str],
508 for_verification: bool = False,
503509 ) -> List[str]:
504510 """Given an event and current state return the list of event IDs used
505511 to auth an event.
2525
2626 # the maximum length for a user id is 255 characters
2727 MAX_USERID_LENGTH = 255
28
29 # The maximum length for a group id is 255 characters
30 MAX_GROUPID_LENGTH = 255
31 MAX_GROUP_CATEGORYID_LENGTH = 255
32 MAX_GROUP_ROLEID_LENGTH = 255
2833
2934
3035 class Membership:
127132
128133
129134 class RelationTypes:
130 """The types of relations known to this server.
131 """
135 """The types of relations known to this server."""
132136
133137 ANNOTATION = "m.annotation"
134138 REPLACE = "m.replace"
389389
390390
391391 class LimitExceededError(SynapseError):
392 """A client has sent too many requests and is being throttled.
393 """
392 """A client has sent too many requests and is being throttled."""
394393
395394 def __init__(
396395 self,
407406
408407
409408 class RoomKeysVersionError(SynapseError):
410 """A client has tried to upload to a non-current version of the room_keys store
411 """
409 """A client has tried to upload to a non-current version of the room_keys store"""
412410
413411 def __init__(self, current_version: str):
414412 """
425423
426424 def __init__(self, msg: str = "Homeserver does not support this room version"):
427425 super().__init__(
428 code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
426 code=400,
427 msg=msg,
428 errcode=Codes.UNSUPPORTED_ROOM_VERSION,
429429 )
430430
431431
460460
461461
462462 class PasswordRefusedError(SynapseError):
463 """A password has been refused, either during password reset/change or registration.
464 """
463 """A password has been refused, either during password reset/change or registration."""
465464
466465 def __init__(
467466 self,
469468 errcode: str = Codes.WEAK_PASSWORD,
470469 ):
471470 super().__init__(
472 code=400, msg=msg, errcode=errcode,
471 code=400,
472 msg=msg,
473 errcode=errcode,
473474 )
474475
475476
492493
493494
494495 def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
495 """ Utility method for constructing an error response for client-server
496 """Utility method for constructing an error response for client-server
496497 interactions.
497498
498499 Args:
509510
510511
511512 class FederationError(RuntimeError):
512 """ This class is used to inform remote homeservers about erroneous
513 """This class is used to inform remote homeservers about erroneous
513514 PDUs they sent us.
514515
515516 FATAL: The remote server could not interpret the source event.
5555
5656 @classmethod
5757 def default(cls, user_id):
58 """Returns a default presence state.
59 """
58 """Returns a default presence state."""
6059 return cls(
6160 user_id=user_id,
6261 state=PresenceState.OFFLINE,
5757
5858
5959 def start_worker_reactor(appname, config, run_command=reactor.run):
60 """ Run the reactor in the main process
60 """Run the reactor in the main process
6161
6262 Daemonizes if necessary, and then configures some resources, before starting
6363 the reactor. Pulls configuration from the 'worker' settings in 'config'.
9292 logger,
9393 run_command=reactor.run,
9494 ):
95 """ Run the reactor in the main process
95 """Run the reactor in the main process
9696
9797 Daemonizes if necessary, and then configures some resources, before starting
9898 the reactor
312312 refresh_certificate(hs)
313313
314314 # Start the tracer
315 synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
316 hs
317 )
315 synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
318316
319317 # It is now safe to start your Synapse.
320318 hs.start_listening(listeners)
369367
370368
371369 def setup_sdnotify(hs):
372 """Adds process state hooks to tell systemd what we are up to.
373 """
370 """Adds process state hooks to tell systemd what we are up to."""
374371
375372 # Tell systemd our state, if we're using it. This will silently fail if
376373 # we're not using systemd.
404401
405402
406403 class _LimitedHostnameResolver:
407 """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
408 """
404 """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
409405
410406 def __init__(self, resolver, max_dns_requests_in_flight):
411407 self._resolver = resolver
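The `_LimitedHostnameResolver` docstring above describes capping the number of in-flight DNS lookups. The same idea in a dependency-free asyncio sketch (Synapse's real class wraps a Twisted `IHostnameResolver`; here the lookup coroutine is injected so the limiter stays generic):

```python
import asyncio


class LimitedResolver:
    """Cap the number of concurrently running lookups with a semaphore.

    `peak` records the highest observed concurrency, which makes the
    limiting behaviour easy to verify.
    """

    def __init__(self, lookup, max_in_flight: int = 5):
        self._lookup = lookup
        self._sem = asyncio.Semaphore(max_in_flight)
        self._in_flight = 0
        self.peak = 0

    async def resolve(self, hostname: str):
        # Waits here once max_in_flight lookups are already active.
        async with self._sem:
            self._in_flight += 1
            self.peak = max(self.peak, self._in_flight)
            try:
                return await self._lookup(hostname)
            finally:
                self._in_flight -= 1
```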
420420 ]
421421
422422 async def set_state(self, target_user, state, ignore_status_msg=False):
423 """Set the presence state of the user.
424 """
423 """Set the presence state of the user."""
425424 presence = state["presence"]
426425
427426 valid_presence = (
165165
166166 @cached(num_args=1, cache_context=True)
167167 async def matches_user_in_member_list(
168 self, room_id: str, store: "DataStore", cache_context: _CacheContext,
168 self,
169 room_id: str,
170 store: "DataStore",
171 cache_context: _CacheContext,
169172 ) -> bool:
170173 """Check if this service is interested a room based upon it's membership
171174
7575 fields = r["fields"]
7676 if not isinstance(fields, dict):
7777 return False
78 for k in fields.keys():
79 if not isinstance(fields[k], str):
80 return False
8178
8279 return True
8380
229226
230227 try:
231228 await self.put_json(
232 uri=uri, json_body=body, args={"access_token": service.hs_token},
229 uri=uri,
230 json_body=body,
231 args={"access_token": service.hs_token},
233232 )
234233 sent_transactions_counter.labels(service.id).inc()
235234 sent_events_counter.labels(service.id).inc(len(events))
6767
6868
6969 class ApplicationServiceScheduler:
70 """ Public facing API for this module. Does the required DI to tie the
70 """Public facing API for this module. Does the required DI to tie the
7171 components together. This also serves as the "event_pool", which in this
7272 case is a simple array.
7373 """
223223 return self.read_templates([filename])[0]
224224
225225 def read_templates(
226 self, filenames: List[str], custom_template_directory: Optional[str] = None,
226 self,
227 filenames: List[str],
228 custom_template_directory: Optional[str] = None,
227229 ) -> List[jinja2.Template]:
228230 """Load a list of template files from disk using the given variables.
229231
263265
264266 # TODO: switch to synapse.util.templates.build_jinja_env
265267 loader = jinja2.FileSystemLoader(search_directories)
266 env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
268 env = jinja2.Environment(
269 loader=loader,
270 autoescape=jinja2.select_autoescape(),
271 )
267272
268273 # Update the environment with our custom filters
269274 env.filters.update(
824829 instances = attr.ib(type=List[str])
825830
826831 def should_handle(self, instance_name: str, key: str) -> bool:
827 """Whether this instance is responsible for handling the given key.
828 """
832 """Whether this instance is responsible for handling the given key."""
829833 # If multiple instances are not defined we always return true
830834 if not self.instances or len(self.instances) == 1:
831835 return True
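`should_handle` deterministically assigns each key to exactly one of the configured instances. A minimal sketch of that idea — hash the key, index into the instance list — follows; the exact hash and indexing here are illustrative, not necessarily byte-for-byte what Synapse uses:

```python
import hashlib
from typing import List


def pick_instance(instances: List[str], key: str) -> str:
    """Stable assignment of a key to one configured instance (sketch)."""
    if len(instances) == 1:
        return instances[0]
    # Hash the key so the same key always maps to the same instance.
    digest = hashlib.sha256(key.encode("utf8")).digest()
    return instances[int.from_bytes(digest, "big") % len(instances)]


def should_handle(instances: List[str], instance_name: str, key: str) -> bool:
    # If multiple instances are not defined we always return true,
    # matching the guard in the diffed code above.
    if not instances or len(instances) == 1:
        return True
    return pick_instance(instances, key) == instance_name
```

Because the assignment depends only on the key and the instance list, every worker independently agrees on which one of them owns a given key.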
1717
1818
1919 class AuthConfig(Config):
20 """Password and login configuration
21 """
20 """Password and login configuration"""
2221
2322 section = "auth"
2423
3736
3837 # User-interactive authentication
3938 ui_auth = config.get("ui_auth") or {}
40 self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0)
39 self.ui_auth_session_timeout = self.parse_duration(
40 ui_auth.get("session_timeout", 0)
41 )
4142
4243 def generate_config_section(self, config_dir_path, server_name, **kwargs):
4344 return """\
9394 #require_uppercase: true
9495
9596 ui_auth:
96 # The number of milliseconds to allow a user-interactive authentication
97 # session to be active.
97 # The amount of time to allow a user-interactive authentication session
98 # to be active.
9899 #
99100 # This defaults to 0, meaning the user is queried for their credentials
100 # before every action, but this can be overridden to alow a single
101 # before every action, but this can be overridden to allow a single
101102 # validation to be re-used. This weakens the protections afforded by
102103 # the user-interactive authentication process, by allowing for multiple
103104 # (and potentially different) operations to use the same validation session.
105106 # Uncomment below to allow for credential validation to last for 15
106107 # seconds.
107108 #
108 #session_timeout: 15000
109 #session_timeout: "15s"
109110 """
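With this change `session_timeout` is fed through `parse_duration`, so both a raw millisecond count and a suffixed string such as `"15s"` are accepted. A rough stand-in for that parsing logic (the real Synapse helper supports more suffixes; this sketch is only illustrative):

```python
import re


def parse_duration(value) -> int:
    """Return a duration in milliseconds from an int or a string like "15s"."""
    if isinstance(value, int):
        return value
    units = {"s": 1000, "m": 60 * 1000, "h": 60 * 60 * 1000, "d": 24 * 60 * 60 * 1000}
    match = re.fullmatch(r"(\d+)([smhd])", value)
    if match:
        return int(match.group(1)) * units[match.group(2)]
    # A bare numeric string is taken as milliseconds.
    return int(value)


print(parse_duration("15s"))  # 15000
```

This is why the sample config comment changes from `session_timeout: 15000` to `session_timeout: "15s"`: both spellings now produce the same 15000 ms timeout.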
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15 from typing import Any, List
16
17 from synapse.config.sso import SsoAttributeRequirement
18
1519 from ._base import Config, ConfigError
20 from ._util import validate_config
1621
1722
1823 class CasConfig(Config):
3944 # TODO Update this to a _synapse URL.
4045 self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
4146 self.cas_displayname_attribute = cas_config.get("displayname_attribute")
42 self.cas_required_attributes = cas_config.get("required_attributes") or {}
47 required_attributes = cas_config.get("required_attributes") or {}
48 self.cas_required_attributes = _parsed_required_attributes_def(
49 required_attributes
50 )
51
4352 else:
4453 self.cas_server_url = None
4554 self.cas_service_url = None
4655 self.cas_displayname_attribute = None
47 self.cas_required_attributes = {}
56 self.cas_required_attributes = []
4857
4958 def generate_config_section(self, config_dir_path, server_name, **kwargs):
5059 return """\
7685 # userGroup: "staff"
7786 # department: None
7887 """
88
89
90 # CAS uses a legacy required attributes mapping, not the one provided by
91 # SsoAttributeRequirement.
92 REQUIRED_ATTRIBUTES_SCHEMA = {
93 "type": "object",
94 "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
95 }
96
97
98 def _parsed_required_attributes_def(
99 required_attributes: Any,
100 ) -> List[SsoAttributeRequirement]:
101 validate_config(
102 REQUIRED_ATTRIBUTES_SCHEMA,
103 required_attributes,
104 config_path=("cas_config", "required_attributes"),
105 )
106 return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
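Per the schema above, each CAS `required_attributes` value may be a string (the attribute must carry that value) or null (the attribute merely has to exist). A sketch of how such requirements might be checked against a CAS response — the helper name and data shapes here are hypothetical, not Synapse's actual handler code:

```python
from typing import Mapping, Optional, Sequence, Tuple


def cas_attributes_ok(
    required: Sequence[Tuple[str, Optional[str]]],
    attributes: Mapping[str, Sequence[str]],
) -> bool:
    """Check CAS response attributes against (attribute, value) requirements."""
    for attribute, value in required:
        if attribute not in attributes:
            return False
        # A None value only requires the attribute to be present.
        if value is not None and value not in attributes[attribute]:
            return False
    return True
```

This mirrors the sample config in the hunk above, where `userGroup: "staff"` demands a specific value but `department: None` only demands presence.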
206206 )
207207
208208 def get_single_database(self) -> DatabaseConnectionConfig:
209 """Returns the database if there is only one, useful for e.g. tests
210 """
209 """Returns the database if there is only one, useful for e.g. tests"""
211210 if not self.databases:
212211 raise Exception("More than one database exists")
213212
288288 self.email_notif_template_html,
289289 self.email_notif_template_text,
290290 ) = self.read_templates(
291 [notif_template_html, notif_template_text], template_dir,
291 [notif_template_html, notif_template_text],
292 template_dir,
292293 )
293294
294295 self.email_notif_for_new_users = email_config.get(
310311 self.account_validity_template_html,
311312 self.account_validity_template_text,
312313 ) = self.read_templates(
313 [expiry_template_html, expiry_template_text], template_dir,
314 [expiry_template_html, expiry_template_text],
315 template_dir,
314316 )
315317
316318 subjects_config = email_config.get("subjects", {})
161161 )
162162
163163 logging_group.add_argument(
164 "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS,
164 "-f",
165 "--log-file",
166 dest="log_file",
167 help=argparse.SUPPRESS,
165168 )
166169
167170 def generate_files(self, config, config_dir_path):
200200 # user_mapping_provider:
201201 # config:
202202 # subject_claim: "id"
203 # localpart_template: "{{ user.login }}"
204 # display_name_template: "{{ user.name }}"
205 # email_template: "{{ user.email }}"
203 # localpart_template: "{{{{ user.login }}}}"
204 # display_name_template: "{{{{ user.name }}}}"
205 # email_template: "{{{{ user.email }}}}"
206206
207207 # For use with Keycloak
208208 #
229229 # user_mapping_provider:
230230 # config:
231231 # subject_claim: "id"
232 # localpart_template: "{{ user.login }}"
233 # display_name_template: "{{ user.name }}"
232 # localpart_template: "{{{{ user.login }}}}"
233 # display_name_template: "{{{{ user.name }}}}"
234234 """.format(
235235 mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
236236 )
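The brace-doubling in these Jinja examples is needed because the whole config-section string is passed through `str.format()`, which collapses each escaped pair, so `{{{{` survives as the `{{` that Jinja expects. A quick check:

```python
# str.format treats "{{" as a literal "{", so four braces in the source
# template come out as the two braces Jinja needs in the generated config.
line = '# localpart_template: "{{{{ user.login }}}}"'.format()
print(line)  # # localpart_template: "{{ user.login }}"
```

Without the doubling, `.format()` would try to interpret `{ user.login }` as a replacement field and raise a `KeyError`.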
354354 ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
355355 ump_config.setdefault("config", {})
356356
357 (user_mapping_provider_class, user_mapping_provider_config,) = load_module(
358 ump_config, config_path + ("user_mapping_provider",)
359 )
357 (
358 user_mapping_provider_class,
359 user_mapping_provider_config,
360 ) = load_module(ump_config, config_path + ("user_mapping_provider",))
360361
361362 # Ensure loaded user mapping module has defined all necessary methods
362363 required_methods = [
371372 if missing_methods:
372373 raise ConfigError(
373374 "Class %s is missing required "
374 "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
375 "methods: %s"
376 % (
377 user_mapping_provider_class,
378 ", ".join(missing_methods),
379 ),
375380 config_path + ("user_mapping_provider", "module"),
376381 )
377382
390390 # By default, any room aliases included in this list will be created
391391 # as a publicly joinable room when the first user registers for the
392392 # homeserver. This behaviour can be customised with the settings below.
393 # If the room already exists, make certain it is a publicly joinable
394 # room. The join rule of the room must be set to 'public'.
393395 #
394396 #auto_join_rooms:
395397 # - "#example:example.com"
1616 from collections import namedtuple
1717 from typing import Dict, List
1818
19 from netaddr import IPSet
20
21 from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST
19 from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
2220 from synapse.python_dependencies import DependencyException, check_requirements
2321 from synapse.util.module_loader import load_module
2422
5351
5452
5553 def parse_thumbnail_requirements(thumbnail_sizes):
56 """ Takes a list of dictionaries with "width", "height", and "method" keys
54 """Takes a list of dictionaries with "width", "height", and "method" keys
5755 and creates a map from image media types to the thumbnail size, thumbnailing
5856 method, and thumbnail media type to precalculate
5957
186184 "to work"
187185 )
188186
189 self.url_preview_ip_range_blacklist = IPSet(
190 config["url_preview_ip_range_blacklist"]
191 )
192
193187 # we always blacklist '0.0.0.0' and '::', which are supposed to be
194188 # unroutable addresses.
195 self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"])
196
197 self.url_preview_ip_range_whitelist = IPSet(
198 config.get("url_preview_ip_range_whitelist", ())
189 self.url_preview_ip_range_blacklist = generate_ip_set(
190 config["url_preview_ip_range_blacklist"],
191 ["0.0.0.0", "::"],
192 config_path=("url_preview_ip_range_blacklist",),
193 )
194
195 self.url_preview_ip_range_whitelist = generate_ip_set(
196 config.get("url_preview_ip_range_whitelist", ()),
197 config_path=("url_preview_ip_range_whitelist",),
199198 )
200199
201200 self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
122122 alias (str)
123123
124124 Returns:
125 boolean: True if user is allowed to crate the alias
125 boolean: True if user is allowed to create the alias
126126 """
127127 for rule in self._alias_creation_rules:
128128 if rule.matches(user_id, room_id, [alias]):
1616 import logging
1717 from typing import Any, List
1818
19 import attr
20
19 from synapse.config.sso import SsoAttributeRequirement
2120 from synapse.python_dependencies import DependencyException, check_requirements
2221 from synapse.util.module_loader import load_module, load_python_module
2322
397396 }
398397
399398
400 @attr.s(frozen=True)
401 class SamlAttributeRequirement:
402 """Object describing a single requirement for SAML attributes."""
403
404 attribute = attr.ib(type=str)
405 value = attr.ib(type=str)
406
407 JSON_SCHEMA = {
408 "type": "object",
409 "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
410 "required": ["attribute", "value"],
411 }
412
413
414399 ATTRIBUTE_REQUIREMENTS_SCHEMA = {
415400 "type": "array",
416 "items": SamlAttributeRequirement.JSON_SCHEMA,
401 "items": SsoAttributeRequirement.JSON_SCHEMA,
417402 }
418403
419404
420405 def _parse_attribute_requirements_def(
421406 attribute_requirements: Any,
422 ) -> List[SamlAttributeRequirement]:
407 ) -> List[SsoAttributeRequirement]:
423408 validate_config(
424409 ATTRIBUTE_REQUIREMENTS_SCHEMA,
425410 attribute_requirements,
426 config_path=["saml2_config", "attribute_requirements"],
411 config_path=("saml2_config", "attribute_requirements"),
427412 )
428 return [SamlAttributeRequirement(**x) for x in attribute_requirements]
413 return [SsoAttributeRequirement(**x) for x in attribute_requirements]
1414 # See the License for the specific language governing permissions and
1515 # limitations under the License.
1616
17 import itertools
1718 import logging
1819 import os.path
1920 import re
2223
2324 import attr
2425 import yaml
25 from netaddr import IPSet
26 from netaddr import AddrFormatError, IPNetwork, IPSet
2627
2728 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
2829 from synapse.util.stringutils import parse_and_validate_server_name
3940 # in the list.
4041 DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
4142
43
44 def _6to4(network: IPNetwork) -> IPNetwork:
45 """Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056."""
46
47 # 6to4 networks consist of:
48 # * 2002 as the first 16 bits
49 # * The first IPv4 address in the network hex-encoded as the next 32 bits
50 # * The new prefix length needs to include the bits from the 2002 prefix.
51 hex_network = hex(network.first)[2:]
52 hex_network = ("0" * (8 - len(hex_network))) + hex_network
53 return IPNetwork(
54 "2002:%s:%s::/%d"
55 % (
56 hex_network[:4],
57 hex_network[4:],
58 16 + network.prefixlen,
59 )
60 )
61
62
63 def generate_ip_set(
64 ip_addresses: Optional[Iterable[str]],
65 extra_addresses: Optional[Iterable[str]] = None,
66 config_path: Optional[Iterable[str]] = None,
67 ) -> IPSet:
68 """
69 Generate an IPSet from a list of IP addresses or CIDRs.
70
71 Additionally, for each IPv4 network in the list of IP addresses, also
72 includes the corresponding IPv6 networks.
73
74 This includes:
75
76 * IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1)
77 * IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2)
78 * 6to4 Address (see RFC 3056, section 2)
79
80 Args:
81 ip_addresses: An iterable of IP addresses or CIDRs.
82 extra_addresses: An iterable of IP addresses or CIDRs.
83 config_path: The path in the configuration for error messages.
84
85 Returns:
86 A new IP set.
87 """
88 result = IPSet()
89 for ip in itertools.chain(ip_addresses or (), extra_addresses or ()):
90 try:
91 network = IPNetwork(ip)
92 except AddrFormatError as e:
93 raise ConfigError(
94 "Invalid IP range provided: %s." % (ip,), config_path
95 ) from e
96 result.add(network)
97
98 # It is possible that these already exist in the set, but that's OK.
99 if ":" not in str(network):
100 result.add(IPNetwork(network).ipv6(ipv4_compatible=True))
101 result.add(IPNetwork(network).ipv6(ipv4_compatible=False))
102 result.add(_6to4(network))
103
104 return result
105
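`_6to4` embeds the IPv4 prefix in the 2002::/16 space: the network's first IPv4 address becomes the next 32 bits of the IPv6 network, and the prefix length grows by 16. The same mapping can be reproduced with the stdlib `ipaddress` module (a sketch for illustration; the diffed code uses netaddr):

```python
import ipaddress


def sixto4(network: ipaddress.IPv4Network) -> ipaddress.IPv6Network:
    """Convert an IPv4 network into its 6to4 IPv6 network per RFC 3056."""
    # 2002::/16 occupies the top 16 bits; the 32-bit IPv4 prefix follows.
    v4 = int(network.network_address)
    return ipaddress.IPv6Network(((0x2002 << 112) | (v4 << 80), 16 + network.prefixlen))


print(sixto4(ipaddress.ip_network("192.0.2.0/24")))  # 2002:c000:200::/40
```

For example, 192.0.2.0 is 0xC0000200, so its /24 becomes 2002:c000:200::/40, matching the hex-splitting done by `_6to4` above.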
106
107 # IP ranges that are considered private / unroutable / don't make sense.
42108 DEFAULT_IP_RANGE_BLACKLIST = [
43109 # Localhost
44110 "127.0.0.0/8",
52118 "192.0.0.0/24",
53119 # Link-local networks.
54120 "169.254.0.0/16",
121 # Formerly used for 6to4 relay.
122 "192.88.99.0/24",
55123 # Testing networks.
56124 "198.18.0.0/15",
57125 "192.0.2.0/24",
65133 "fe80::/10",
66134 # Unique local addresses.
67135 "fc00::/7",
136 # Testing networks.
137 "2001:db8::/32",
138 # Multicast.
139 "ff00::/8",
140 # Site-local addresses
141 "fec0::/10",
68142 ]
69143
70144 DEFAULT_ROOM_VERSION = "6"
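The reason `generate_ip_set` also adds the IPv4-compatible, IPv4-mapped, and 6to4 equivalents of each IPv4 range: an IPv6 literal such as `::ffff:10.1.2.3` would otherwise slip past a purely IPv4 blacklist. A small demonstration with the stdlib `ipaddress` module:

```python
import ipaddress

blocked = ipaddress.ip_network("10.0.0.0/8")
probe = ipaddress.ip_address("::ffff:10.1.2.3")  # IPv4-mapped IPv6 form

# Mixed-version membership is always False, so a naive check misses it...
assert (probe in blocked) is False

# ...but the mapped form unwraps to an address squarely inside the range,
# which is why the equivalent IPv6 networks are added to the blacklist too.
assert probe.ipv4_mapped in blocked
```

Blacklisting `::ffff:10.0.0.0/104` alongside `10.0.0.0/8` (as `generate_ip_set` effectively does) closes this loophole for URL previews and federation requests.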
184258 # Whether to require sharing a room with a user to retrieve their
185259 # profile data
186260 self.limit_profile_requests_to_users_who_share_rooms = config.get(
187 "limit_profile_requests_to_users_who_share_rooms", False,
261 "limit_profile_requests_to_users_who_share_rooms",
262 False,
188263 )
189264
190265 if "restrict_public_rooms_to_local_users" in config and (
289364 )
290365
291366 # Attempt to create an IPSet from the given ranges
292 try:
293 self.ip_range_blacklist = IPSet(ip_range_blacklist)
294 except Exception as e:
295 raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e
367
296368 # Always blacklist 0.0.0.0, ::
297 self.ip_range_blacklist.update(["0.0.0.0", "::"])
298
299 try:
300 self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ()))
301 except Exception as e:
302 raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e
369 self.ip_range_blacklist = generate_ip_set(
370 ip_range_blacklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",)
371 )
372
373 self.ip_range_whitelist = generate_ip_set(
374 config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",)
375 )
303376
304377 # The federation_ip_range_blacklist is used for backwards-compatibility
305378 # and only applies to federation and identity servers. If it is not given,
307380 federation_ip_range_blacklist = config.get(
308381 "federation_ip_range_blacklist", ip_range_blacklist
309382 )
310 try:
311 self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
312 except Exception as e:
313 raise ConfigError(
314 "Invalid range(s) provided in federation_ip_range_blacklist."
315 ) from e
316383 # Always blacklist 0.0.0.0, ::
317 self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
384 self.federation_ip_range_blacklist = generate_ip_set(
385 federation_ip_range_blacklist,
386 ["0.0.0.0", "::"],
387 config_path=("federation_ip_range_blacklist",),
388 )
318389
319390 if self.public_baseurl is not None:
320391 if self.public_baseurl[-1] != "/":
548619 if manhole:
549620 self.listeners.append(
550621 ListenerConfig(
551 port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
622 port=manhole,
623 bind_addresses=["127.0.0.1"],
624 type="manhole",
552625 )
553626 )
554627
584657 # and letting the client know which email address is bound to an account and
585658 # which one isn't.
586659 self.request_token_inhibit_3pid_errors = config.get(
587 "request_token_inhibit_3pid_errors", False,
660 "request_token_inhibit_3pid_errors",
661 False,
588662 )
589663
590664 # List of users trialing the new experimental default push rules. This setting is
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Any, Dict
14 from typing import Any, Dict, Optional
15
16 import attr
1517
1618 from ._base import Config
1719
1820
21 @attr.s(frozen=True)
22 class SsoAttributeRequirement:
23 """Object describing a single requirement for SSO attributes."""
24
25 attribute = attr.ib(type=str)
 26 # If a value is not given, then the attribute must simply exist.
27 value = attr.ib(type=Optional[str])
28
29 JSON_SCHEMA = {
30 "type": "object",
31 "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
32 "required": ["attribute", "value"],
33 }
34
35
1936 class SSOConfig(Config):
20 """SSO Configuration
21 """
37 """SSO Configuration"""
2238
2339 section = "sso"
2440
3232
3333 @attr.s
3434 class InstanceLocationConfig:
35 """The host and port to talk to an instance via HTTP replication.
36 """
35 """The host and port to talk to an instance via HTTP replication."""
3736
3837 host = attr.ib(type=str)
3938 port = attr.ib(type=int)
5352 )
5453 typing = attr.ib(default="master", type=str)
5554 to_device = attr.ib(
56 default=["master"], type=List[str], converter=_instance_to_list_converter,
55 default=["master"],
56 type=List[str],
57 converter=_instance_to_list_converter,
5758 )
5859 account_data = attr.ib(
59 default=["master"], type=List[str], converter=_instance_to_list_converter,
60 default=["master"],
61 type=List[str],
62 converter=_instance_to_list_converter,
6063 )
6164 receipts = attr.ib(
62 default=["master"], type=List[str], converter=_instance_to_list_converter,
65 default=["master"],
66 type=List[str],
67 converter=_instance_to_list_converter,
6368 )
6469
6570
106111 if manhole:
107112 self.worker_listeners.append(
108113 ListenerConfig(
109 port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
114 port=manhole,
115 bind_addresses=["127.0.0.1"],
116 type="manhole",
110117 )
111118 )
112119
4141 do_sig_check: bool = True,
4242 do_size_check: bool = True,
4343 ) -> None:
44 """ Checks if this event is correctly authed.
44 """Checks if this event is correctly authed.
4545
4646 Args:
4747 room_version_obj: the version of the room
422422
423423
424424 def check_redaction(
425 room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
425 room_version_obj: RoomVersion,
426 event: EventBase,
427 auth_events: StateMap[EventBase],
426428 ) -> bool:
427429 """Check whether the event sender is allowed to redact the target event.
428430
458460
459461
460462 def _check_power_levels(
461 room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
463 room_version_obj: RoomVersion,
464 event: EventBase,
465 auth_events: StateMap[EventBase],
462466 ) -> None:
463467 user_list = event.content.get("users", {})
464468 # Validate users
9797 return self._state_key is not None
9898
9999 async def build(
100 self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
100 self,
101 prev_event_ids: List[str],
102 auth_event_ids: Optional[List[str]],
101103 ) -> EventBase:
102104 """Transform into a fully signed and hashed event
103105
340340
341341
342342 def _decode_state_dict(input):
343 """Decodes a state dict encoded using `_encode_state_dict` above
344 """
343 """Decodes a state dict encoded using `_encode_state_dict` above"""
345344 if input is None:
346345 return None
347346
1616 import inspect
1717 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1818
19 from synapse.rest.media.v1._base import FileInfo
20 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
1921 from synapse.spam_checker_api import RegistrationBehaviour
2022 from synapse.types import Collection
2123 from synapse.util.async_helpers import maybe_awaitable
213215 return behaviour
214216
215217 return RegistrationBehaviour.ALLOW
218
219 async def check_media_file_for_spam(
220 self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
221 ) -> bool:
222 """Checks if a piece of newly uploaded media should be blocked.
223
224 This will be called for local uploads, downloads of remote media, each
225 thumbnail generated for those, and web pages/images used for URL
226 previews.
227
228 Note that care should be taken to not do blocking IO operations in the
229 main thread. For example, to get the contents of a file a module
230 should do::
231
232 async def check_media_file_for_spam(
233 self, file: ReadableFileWrapper, file_info: FileInfo
234 ) -> bool:
235 buffer = BytesIO()
236 await file.write_chunks_to(buffer.write)
237
238 if buffer.getvalue() == b"Hello World":
239 return True
240
241 return False
242
243
244 Args:
245 file: An object that allows reading the contents of the media.
246 file_info: Metadata about the file.
247
248 Returns:
249 True if the media should be blocked or False if it should be
250 allowed.
251 """
252
253 for spam_checker in self.spam_checkers:
254 # For backwards compatibility, only run if the method exists on the
255 # spam checker
256 checker = getattr(spam_checker, "check_media_file_for_spam", None)
257 if checker:
258 spam = await maybe_awaitable(checker(file_wrapper, file_info))
259 if spam:
260 return True
261
262 return False
3939
4040 if module is not None:
4141 self.third_party_rules = module(
42 config=config, module_api=hs.get_module_api(),
42 config=config,
43 module_api=hs.get_module_api(),
4344 )
4445
4546 async def check_event_allowed(
3333
3434
3535 def prune_event(event: EventBase) -> EventBase:
36 """ Returns a pruned version of the given event, which removes all keys we
36 """Returns a pruned version of the given event, which removes all keys we
3737 don't know about or think could potentially be dodgy.
3838
3939 This is used when we "redact" an event. We want to remove all fields that
749749 return resp[1]
750750
751751 async def send_invite(
752 self, destination: str, room_id: str, event_id: str, pdu: EventBase,
752 self,
753 destination: str,
754 room_id: str,
755 event_id: str,
756 pdu: EventBase,
753757 ) -> EventBase:
754758 room_version = await self.store.get_room_version(room_id)
755759
8484 )
8585
8686 pdu_process_time = Histogram(
87 "synapse_federation_server_pdu_process_time", "Time taken to process an event",
87 "synapse_federation_server_pdu_process_time",
88 "Time taken to process an event",
8889 )
8990
9091
203204 async def _handle_incoming_transaction(
204205 self, origin: str, transaction: Transaction, request_time: int
205206 ) -> Tuple[int, Dict[str, Any]]:
206 """ Process an incoming transaction and return the HTTP response
207 """Process an incoming transaction and return the HTTP response
207208
208209 Args:
209210 origin: the server making the request
372373 return pdu_results
373374
374375 async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
375 """Process the EDUs in a received transaction.
376 """
376 """Process the EDUs in a received transaction."""
377377
378378 async def _process_edu(edu_dict):
379379 received_edus_counter.inc()
436436 raise AuthError(403, "Host not in room.")
437437
438438 resp = await self._state_ids_resp_cache.wrap(
439 (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
439 (room_id, event_id),
440 self._on_state_ids_request_compute,
441 room_id,
442 event_id,
440443 )
441444
442445 return 200, resp
678681 )
679682
680683 async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
681 """ Process a PDU received in a federation /send/ transaction.
684 """Process a PDU received in a federation /send/ transaction.
682685
683686 If the event is invalid, then this method throws a FederationError.
684687 (The error will then be logged and sent back to the sender (which
905908 self.query_handlers[query_type] = handler
906909
907910 def register_instance_for_edu(self, edu_type: str, instance_name: str):
908 """Register that the EDU handler is on a different instance than master.
909 """
911 """Register that the EDU handler is on a different instance than master."""
910912 self._edu_type_to_instance[edu_type] = [instance_name]
911913
912914 def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
913 """Register that the EDU handler is on multiple instances.
914 """
915 """Register that the EDU handler is on multiple instances."""
915916 self._edu_type_to_instance[edu_type] = instance_names
916917
917918 async def on_edu(self, edu_type: str, origin: str, content: dict):
2929
3030
3131 class TransactionActions:
32 """ Defines persistence actions that relate to handling Transactions.
33 """
32 """Defines persistence actions that relate to handling Transactions."""
3433
3534 def __init__(self, datastore):
3635 self.store = datastore
5655 async def set_response(
5756 self, origin: str, transaction: Transaction, code: int, response: JsonDict
5857 ) -> None:
59 """Persist how we responded to a transaction.
60 """
58 """Persist how we responded to a transaction."""
6159 transaction_id = transaction.transaction_id # type: ignore
6260 if not transaction_id:
6361 raise RuntimeError("Cannot persist a transaction with no transaction_id")
467467
468468
469469 class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
470 """Streams EDUs that don't have keys. See KeyedEduRow
471 """
470 """Streams EDUs that don't have keys. See KeyedEduRow"""
472471
473472 TypeId = "e"
474473
518517 # them into the appropriate collection and then send them off.
519518
520519 buff = ParsedFederationStreamData(
521 presence=[], presence_destinations=[], keyed_edus={}, edus={},
520 presence=[],
521 presence_destinations=[],
522 keyed_edus={},
523 edus={},
522524 )
523525
524526 # Parse the rows in the stream and add to the buffer
327327 # to allow us to perform catch-up later on if the remote is unreachable
328328 # for a while.
329329 await self.store.store_destination_rooms_entries(
330 destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
330 destinations,
331 pdu.room_id,
332 pdu.internal_metadata.stream_ordering,
331333 )
332334
333335 for destination in destinations:
474476 self, states: List[UserPresenceState], destinations: List[str]
475477 ) -> None:
476478 """Send the given presence states to the given destinations.
477 destinations (list[str])
479 destinations (list[str])
478480 """
479481
480482 if not states or not self.hs.config.use_presence:
615617 last_processed = None # type: Optional[str]
616618
617619 while True:
618 destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
619 last_processed
620 destinations_to_wake = (
621 await self.store.get_catch_up_outstanding_destinations(last_processed)
620622 )
621623
622624 if not destinations_to_wake:
8484 # processing. We have a guard in `attempt_new_transaction` that
8585 # ensure we don't start sending stuff.
8686 logger.error(
87 "Create a per destination queue for %s on wrong worker", destination,
87 "Create a per destination queue for %s on wrong worker",
88 destination,
8889 )
8990 self._should_send_on_this_instance = False
9091
439440
440441 if first_catch_up_check:
441442 # first catchup so get last_successful_stream_ordering from database
442 self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
443 self._destination
443 self._last_successful_stream_ordering = (
444 await self._store.get_destination_last_successful_stream_ordering(
445 self._destination
446 )
444447 )
445448
446449 if self._last_successful_stream_ordering is None:
456459 # get at most 50 catchup room/PDUs
457460 while True:
458461 event_ids = await self._store.get_catch_up_room_event_ids(
459 self._destination, self._last_successful_stream_ordering,
462 self._destination,
463 self._last_successful_stream_ordering,
460464 )
461465
462466 if not event_ids:
6464
6565 @measure_func("_send_new_transaction")
6666 async def send_new_transaction(
67 self, destination: str, pdus: List[EventBase], edus: List[Edu],
67 self,
68 destination: str,
69 pdus: List[EventBase],
70 edus: List[Edu],
6871 ) -> bool:
6972 """
7073 Args:
3838
3939 @log_function
4040 def get_room_state_ids(self, destination, room_id, event_id):
41 """ Requests all state for a given room from the given server at the
41 """Requests all state for a given room from the given server at the
4242 given event. Returns the state's event_id's
4343
4444 Args:
6262
6363 @log_function
6464 def get_event(self, destination, event_id, timeout=None):
65 """ Requests the pdu with give id and origin from the given server.
 65 """Requests the PDU with the given id and origin from the given server.
6666
6767 Args:
6868 destination (str): The host name of the remote homeserver we want
8383
8484 @log_function
8585 def backfill(self, destination, room_id, event_tuples, limit):
86 """ Requests `limit` previous PDUs in a given context before list of
 86 """Requests `limit` previous PDUs in a given context before a list of
8787 PDUs.
8888
8989 Args:
117117
118118 @log_function
119119 async def send_transaction(self, transaction, json_data_callback=None):
120 """ Sends the given Transaction to its destination
120 """Sends the given Transaction to its destination
121121
122122 Args:
123123 transaction (Transaction)
550550
551551 @log_function
552552 def get_group_profile(self, destination, group_id, requester_user_id):
553 """Get a group profile
554 """
553 """Get a group profile"""
555554 path = _create_v1_path("/groups/%s/profile", group_id)
556555
557556 return self.client.get_json(
583582
584583 @log_function
585584 def get_group_summary(self, destination, group_id, requester_user_id):
586 """Get a group summary
587 """
585 """Get a group summary"""
588586 path = _create_v1_path("/groups/%s/summary", group_id)
589587
590588 return self.client.get_json(
596594
597595 @log_function
598596 def get_rooms_in_group(self, destination, group_id, requester_user_id):
599 """Get all rooms in a group
600 """
597 """Get all rooms in a group"""
601598 path = _create_v1_path("/groups/%s/rooms", group_id)
602599
603600 return self.client.get_json(
610607 def add_room_to_group(
611608 self, destination, group_id, requester_user_id, room_id, content
612609 ):
613 """Add a room to a group
614 """
610 """Add a room to a group"""
615611 path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
616612
617613 return self.client.post_json(
625621 def update_room_in_group(
626622 self, destination, group_id, requester_user_id, room_id, config_key, content
627623 ):
628 """Update room in group
629 """
624 """Update room in group"""
630625 path = _create_v1_path(
631626 "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
632627 )
640635 )
641636
642637 def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
643 """Remove a room from a group
644 """
638 """Remove a room from a group"""
645639 path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
646640
647641 return self.client.delete_json(
653647
654648 @log_function
655649 def get_users_in_group(self, destination, group_id, requester_user_id):
656 """Get users in a group
657 """
650 """Get users in a group"""
658651 path = _create_v1_path("/groups/%s/users", group_id)
659652
660653 return self.client.get_json(
666659
667660 @log_function
668661 def get_invited_users_in_group(self, destination, group_id, requester_user_id):
669 """Get users that have been invited to a group
670 """
662 """Get users that have been invited to a group"""
671663 path = _create_v1_path("/groups/%s/invited_users", group_id)
672664
673665 return self.client.get_json(
679671
680672 @log_function
681673 def accept_group_invite(self, destination, group_id, user_id, content):
682 """Accept a group invite
683 """
674 """Accept a group invite"""
684675 path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
685676
686677 return self.client.post_json(
689680
690681 @log_function
691682 def join_group(self, destination, group_id, user_id, content):
692 """Attempts to join a group
693 """
683 """Attempts to join a group"""
694684 path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
695685
696686 return self.client.post_json(
701691 def invite_to_group(
702692 self, destination, group_id, user_id, requester_user_id, content
703693 ):
704 """Invite a user to a group
705 """
694 """Invite a user to a group"""
706695 path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
707696
708697 return self.client.post_json(
729718 def remove_user_from_group(
730719 self, destination, group_id, requester_user_id, user_id, content
731720 ):
732 """Remove a user from a group
733 """
721 """Remove a user from a group"""
734722 path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
735723
736724 return self.client.post_json(
771759 def update_group_summary_room(
772760 self, destination, group_id, user_id, room_id, category_id, content
773761 ):
774 """Update a room entry in a group summary
775 """
762 """Update a room entry in a group summary"""
776763 if category_id:
777764 path = _create_v1_path(
778765 "/groups/%s/summary/categories/%s/rooms/%s",
795782 def delete_group_summary_room(
796783 self, destination, group_id, user_id, room_id, category_id
797784 ):
798 """Delete a room entry in a group summary
799 """
785 """Delete a room entry in a group summary"""
800786 if category_id:
801787 path = _create_v1_path(
802788 "/groups/%s/summary/categories/%s/rooms/%s",
816802
817803 @log_function
818804 def get_group_categories(self, destination, group_id, requester_user_id):
819 """Get all categories in a group
820 """
805 """Get all categories in a group"""
821806 path = _create_v1_path("/groups/%s/categories", group_id)
822807
823808 return self.client.get_json(
829814
830815 @log_function
831816 def get_group_category(self, destination, group_id, requester_user_id, category_id):
832 """Get category info in a group
833 """
817 """Get category info in a group"""
834818 path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
835819
836820 return self.client.get_json(
844828 def update_group_category(
845829 self, destination, group_id, requester_user_id, category_id, content
846830 ):
847 """Update a category in a group
848 """
831 """Update a category in a group"""
849832 path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
850833
851834 return self.client.post_json(
860843 def delete_group_category(
861844 self, destination, group_id, requester_user_id, category_id
862845 ):
863 """Delete a category in a group
864 """
846 """Delete a category in a group"""
865847 path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
866848
867849 return self.client.delete_json(
873855
874856 @log_function
875857 def get_group_roles(self, destination, group_id, requester_user_id):
876 """Get all roles in a group
877 """
858 """Get all roles in a group"""
878859 path = _create_v1_path("/groups/%s/roles", group_id)
879860
880861 return self.client.get_json(
886867
887868 @log_function
888869 def get_group_role(self, destination, group_id, requester_user_id, role_id):
889 """Get a roles info
890 """
870 """Get a role's info"""
891871 path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
892872
893873 return self.client.get_json(
901881 def update_group_role(
902882 self, destination, group_id, requester_user_id, role_id, content
903883 ):
904 """Update a role in a group
905 """
884 """Update a role in a group"""
906885 path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
907886
908887 return self.client.post_json(
915894
916895 @log_function
917896 def delete_group_role(self, destination, group_id, requester_user_id, role_id):
918 """Delete a role in a group
919 """
897 """Delete a role in a group"""
920898 path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
921899
922900 return self.client.delete_json(
930908 def update_group_summary_user(
931909 self, destination, group_id, requester_user_id, user_id, role_id, content
932910 ):
933 """Update a users entry in a group
934 """
911 """Update a user's entry in a group"""
935912 if role_id:
936913 path = _create_v1_path(
937914 "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
949926
950927 @log_function
951928 def set_group_join_policy(self, destination, group_id, requester_user_id, content):
952 """Sets the join policy for a group
953 """
929 """Sets the join policy for a group"""
954930 path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
955931
956932 return self.client.put_json(
965941 def delete_group_summary_user(
966942 self, destination, group_id, requester_user_id, user_id, role_id
967943 ):
968 """Delete a users entry in a group
969 """
944 """Delete a user's entry in a group"""
970945 if role_id:
971946 path = _create_v1_path(
972947 "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
982957 )
983958
984959 def bulk_get_publicised_groups(self, destination, user_ids):
985 """Get the groups a list of users are publicising
986 """
960 """Get the groups a list of users are publicising"""
987961
988962 path = _create_v1_path("/get_groups_publicised")
989963
2020 from typing import Optional, Tuple, Type
2121
2222 import synapse
23 from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
2324 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
2425 from synapse.api.room_versions import RoomVersions
2526 from synapse.api.urls import (
363364 continue
364365
365366 server.register_paths(
366 method, (pattern,), self._wrap(code), self.__class__.__name__,
367 method,
368 (pattern,),
369 self._wrap(code),
370 self.__class__.__name__,
367371 )
368372
369373
380384
381385 # This is when someone is trying to send us a bunch of data.
382386 async def on_PUT(self, origin, content, query, transaction_id):
383 """ Called on PUT /send/<transaction_id>/
387 """Called on PUT /send/<transaction_id>/
384388
385389 Args:
386390 request (twisted.web.http.Request): The HTTP request.
854858
855859
856860 class FederationGroupsProfileServlet(BaseFederationServlet):
857 """Get/set the basic profile of a group on behalf of a user
858 """
861 """Get/set the basic profile of a group on behalf of a user"""
859862
860863 PATH = "/groups/(?P<group_id>[^/]*)/profile"
861864
894897
895898
896899 class FederationGroupsRoomsServlet(BaseFederationServlet):
897 """Get the rooms in a group on behalf of a user
898 """
900 """Get the rooms in a group on behalf of a user"""
899901
900902 PATH = "/groups/(?P<group_id>[^/]*)/rooms"
901903
910912
911913
912914 class FederationGroupsAddRoomsServlet(BaseFederationServlet):
913 """Add/remove room from group
914 """
915 """Add/remove room from group"""
915916
916917 PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
917918
939940
940941
941942 class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
942 """Update room config in group
943 """
943 """Update room config in group"""
944944
945945 PATH = (
946946 "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
960960
961961
962962 class FederationGroupsUsersServlet(BaseFederationServlet):
963 """Get the users in a group on behalf of a user
964 """
963 """Get the users in a group on behalf of a user"""
965964
966965 PATH = "/groups/(?P<group_id>[^/]*)/users"
967966
976975
977976
978977 class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
979 """Get the users that have been invited to a group
980 """
978 """Get the users that have been invited to a group"""
981979
982980 PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
983981
994992
995993
996994 class FederationGroupsInviteServlet(BaseFederationServlet):
997 """Ask a group server to invite someone to the group
998 """
995 """Ask a group server to invite someone to the group"""
999996
1000997 PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
1001998
10121009
10131010
10141011 class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
1015 """Accept an invitation from the group server
1016 """
1012 """Accept an invitation from the group server"""
10171013
10181014 PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
10191015
10271023
10281024
10291025 class FederationGroupsJoinServlet(BaseFederationServlet):
1030 """Attempt to join a group
1031 """
1026 """Attempt to join a group"""
10321027
10331028 PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
10341029
10421037
10431038
10441039 class FederationGroupsRemoveUserServlet(BaseFederationServlet):
1045 """Leave or kick a user from the group
1046 """
1040 """Leave or kick a user from the group"""
10471041
10481042 PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
10491043
10601054
10611055
10621056 class FederationGroupsLocalInviteServlet(BaseFederationServlet):
1063 """A group server has invited a local user
1064 """
1057 """A group server has invited a local user"""
10651058
10661059 PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
10671060
10751068
10761069
10771070 class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
1078 """A group server has removed a local user
1079 """
1071 """A group server has removed a local user"""
10801072
10811073 PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
10821074
10921084
10931085
10941086 class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
1095 """A group or user's server renews their attestation
1096 """
1087 """A group or user's server renews their attestation"""
10971088
10981089 PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
10991090
11271118 raise SynapseError(403, "requester_user_id doesn't match origin")
11281119
11291120 if category_id == "":
1130 raise SynapseError(400, "category_id cannot be empty string")
1121 raise SynapseError(
1122 400, "category_id cannot be empty string", Codes.INVALID_PARAM
1123 )
1124
1125 if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
1126 raise SynapseError(
1127 400,
1128 "category_id may not be longer than %s characters"
1129 % (MAX_GROUP_CATEGORYID_LENGTH,),
1130 Codes.INVALID_PARAM,
1131 )
11311132
11321133 resp = await self.handler.update_group_summary_room(
11331134 group_id,
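The hunks above add identical empty-string and maximum-length checks for `category_id` (and, further down, `role_id`) before the handler is invoked. The shared pattern can be sketched as a standalone helper; the `SynapseError` stub and the limit value here are illustrative assumptions, not Synapse's actual classes or constants (the real limits live in `synapse.api.constants`):

```python
# Sketch of the ID-validation pattern added in this release: reject empty or
# over-long group category/role IDs with a 400 and errcode M_INVALID_PARAM.
# SynapseError and MAX_GROUP_CATEGORYID_LENGTH are simplified stand-ins.
MAX_GROUP_CATEGORYID_LENGTH = 100  # assumed value for illustration


class SynapseError(Exception):
    def __init__(self, code: int, msg: str, errcode: str = "M_UNKNOWN"):
        super().__init__(msg)
        self.code = code
        self.errcode = errcode


def check_group_id_param(name: str, value: str, max_length: int) -> None:
    """Validate a group sub-identifier taken from a federation request path."""
    if value == "":
        raise SynapseError(
            400, "%s cannot be empty string" % (name,), "M_INVALID_PARAM"
        )
    if len(value) > max_length:
        raise SynapseError(
            400,
            "%s may not be longer than %s characters" % (name, max_length),
            "M_INVALID_PARAM",
        )
```

Centralising the check this way would avoid repeating the two-branch validation in each servlet, though the diff above inlines it at every call site.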
11551156
11561157
11571158 class FederationGroupsCategoriesServlet(BaseFederationServlet):
1158 """Get all categories for a group
1159 """
1159 """Get all categories for a group"""
11601160
11611161 PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
11621162
11711171
11721172
11731173 class FederationGroupsCategoryServlet(BaseFederationServlet):
1174 """Add/remove/get a category in a group
1175 """
1174 """Add/remove/get a category in a group"""
11761175
11771176 PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
11781177
11951194 if category_id == "":
11961195 raise SynapseError(400, "category_id cannot be empty string")
11971196
1197 if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
1198 raise SynapseError(
1199 400,
1200 "category_id may not be longer than %s characters"
1201 % (MAX_GROUP_CATEGORYID_LENGTH,),
1202 Codes.INVALID_PARAM,
1203 )
1204
11981205 resp = await self.handler.upsert_group_category(
11991206 group_id, requester_user_id, category_id, content
12001207 )
12171224
12181225
12191226 class FederationGroupsRolesServlet(BaseFederationServlet):
1220 """Get roles in a group
1221 """
1227 """Get roles in a group"""
12221228
12231229 PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
12241230
12331239
12341240
12351241 class FederationGroupsRoleServlet(BaseFederationServlet):
1236 """Add/remove/get a role in a group
1237 """
1242 """Add/remove/get a role in a group"""
12381243
12391244 PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
12401245
12531258 raise SynapseError(403, "requester_user_id doesn't match origin")
12541259
12551260 if role_id == "":
1256 raise SynapseError(400, "role_id cannot be empty string")
1261 raise SynapseError(
1262 400, "role_id cannot be empty string", Codes.INVALID_PARAM
1263 )
1264
1265 if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
1266 raise SynapseError(
1267 400,
1268 "role_id may not be longer than %s characters"
1269 % (MAX_GROUP_ROLEID_LENGTH,),
1270 Codes.INVALID_PARAM,
1271 )
12571272
12581273 resp = await self.handler.update_group_role(
12591274 group_id, requester_user_id, role_id, content
12971312
12981313 if role_id == "":
12991314 raise SynapseError(400, "role_id cannot be empty string")
1315
1316 if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
1317 raise SynapseError(
1318 400,
1319 "role_id may not be longer than %s characters"
1320 % (MAX_GROUP_ROLEID_LENGTH,),
1321 Codes.INVALID_PARAM,
1322 )
13001323
13011324 resp = await self.handler.update_group_summary_user(
13021325 group_id,
13241347
13251348
13261349 class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
1327 """Get roles in a group
1328 """
1350 """Get the groups a list of users are publicising"""
13291351
13301352 PATH = "/get_groups_publicised"
13311353
13381360
13391361
13401362 class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
1341 """Sets whether a group is joinable without an invite or knock
1342 """
1363 """Sets whether a group is joinable without an invite or knock"""
13431364
13441365 PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
13451366
2828
2929 @attr.s(slots=True)
3030 class Edu(JsonEncodedObject):
31 """ An Edu represents a piece of data sent from one homeserver to another.
31 """An Edu represents a piece of data sent from one homeserver to another.
3232
3333 In comparison to Pdus, Edus are not persisted for a long time on disk, are
3434 not meaningful beyond a given pair of homeservers, and don't have an
6262
6363
6464 class Transaction(JsonEncodedObject):
65 """ A transaction is a list of Pdus and Edus to be sent to a remote home
65 """A transaction is a list of Pdus and Edus to be sent to a remote home
6666 server with some extra metadata.
6767
6868 Example transaction::
9898 ]
9999
100100 def __init__(self, transaction_id=None, pdus=[], **kwargs):
101 """ If we include a list of pdus then we decode then as PDU's
101 """If we include a list of pdus then we decode them as PDUs
102102 automatically.
103103 """
104104
110110
111111 @staticmethod
112112 def create_new(pdus, **kwargs):
113 """ Used to create a new transaction. Will auto fill out
113 """Used to create a new transaction. Will auto fill out
114114 transaction_id and origin_server_ts keys.
115115 """
116116 if "origin_server_ts" not in kwargs:
3636
3737 import logging
3838 import random
39 from typing import Tuple
39 from typing import TYPE_CHECKING, Optional, Tuple
4040
4141 from signedjson.sign import sign_json
4242
4343 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
4444 from synapse.metrics.background_process_metrics import run_as_background_process
45 from synapse.types import get_domain_from_id
45 from synapse.types import JsonDict, get_domain_from_id
46
47 if TYPE_CHECKING:
48 from synapse.app.homeserver import HomeServer
4649
4750 logger = logging.getLogger(__name__)
4851
6063
6164
6265 class GroupAttestationSigning:
63 """Creates and verifies group attestations.
64 """
65
66 def __init__(self, hs):
66 """Creates and verifies group attestations."""
67
68 def __init__(self, hs: "HomeServer"):
6769 self.keyring = hs.get_keyring()
6870 self.clock = hs.get_clock()
6971 self.server_name = hs.hostname
7072 self.signing_key = hs.signing_key
7173
7274 async def verify_attestation(
73 self, attestation, group_id, user_id, server_name=None
74 ):
75 self,
76 attestation: JsonDict,
77 group_id: str,
78 user_id: str,
79 server_name: Optional[str] = None,
80 ) -> None:
7581 """Verifies that the given attestation matches the given parameters.
7682
7783 An optional server_name can be supplied to explicitly set which server's
100106 if valid_until_ms < now:
101107 raise SynapseError(400, "Attestation expired")
102108
109 assert server_name is not None
103110 await self.keyring.verify_json_for_server(
104111 server_name, attestation, now, "Group attestation"
105112 )
106113
107 def create_attestation(self, group_id, user_id):
114 def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
108115 """Create an attestation for the group_id and user_id with default
109116 validity length.
110117 """
111 validity_period = DEFAULT_ATTESTATION_LENGTH_MS
112 validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
118 validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
119 *DEFAULT_ATTESTATION_JITTER
120 )
113121 valid_until_ms = int(self.clock.time_msec() + validity_period)
114122
115123 return sign_json(
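The refactored `create_attestation` above multiplies the base validity length by a random jitter factor so that homeservers do not all renew their attestations at the same instant. A minimal sketch of that computation, using assumed constant values rather than Synapse's actual ones:

```python
import random

# Assumed values for illustration; the real constants are defined in
# synapse/groups/attestations.py.
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000  # three days
DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)


def attestation_valid_until(now_ms: int) -> int:
    """Pick an expiry time, jittered to spread renewal load across servers."""
    validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
        *DEFAULT_ATTESTATION_JITTER
    )
    return int(now_ms + validity_period)
```

The jitter range straddles 1.0, so some attestations expire slightly early and some slightly late relative to the nominal length.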
124132
125133
126134 class GroupAttestionRenewer:
127 """Responsible for sending and receiving attestation updates.
128 """
129
130 def __init__(self, hs):
135 """Responsible for sending and receiving attestation updates."""
136
137 def __init__(self, hs: "HomeServer"):
131138 self.clock = hs.get_clock()
132139 self.store = hs.get_datastore()
133140 self.assestations = hs.get_groups_attestation_signing()
140147 self._start_renew_attestations, 30 * 60 * 1000
141148 )
142149
143 async def on_renew_attestation(self, group_id, user_id, content):
144 """When a remote updates an attestation
145 """
150 async def on_renew_attestation(
151 self, group_id: str, user_id: str, content: JsonDict
152 ) -> JsonDict:
153 """When a remote updates an attestation"""
146154 attestation = content["attestation"]
147155
148156 if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
156164
157165 return {}
158166
159 def _start_renew_attestations(self):
167 def _start_renew_attestations(self) -> None:
160168 return run_as_background_process("renew_attestations", self._renew_attestations)
161169
162 async def _renew_attestations(self):
163 """Called periodically to check if we need to update any of our attestations
164 """
170 async def _renew_attestations(self) -> None:
171 """Called periodically to check if we need to update any of our attestations"""
165172
166173 now = self.clock.time_msec()
167174
169176 now + UPDATE_ATTESTATION_TIME_MS
170177 )
171178
172 async def _renew_attestation(group_user: Tuple[str, str]):
179 async def _renew_attestation(group_user: Tuple[str, str]) -> None:
173180 group_id, user_id = group_user
174181 try:
175182 if not self.is_mine_id(group_id):
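`_renew_attestations` above runs periodically and renews any attestation that will expire within the update window. The selection predicate it implies can be sketched as follows; both constants are assumed values for illustration:

```python
# Assumed values; the real constants live in synapse/groups/attestations.py.
UPDATE_ATTESTATION_TIME_MS = 24 * 60 * 60 * 1000  # renew a day before expiry


def needs_renewal(valid_until_ms: int, now_ms: int) -> bool:
    """True if an attestation expires inside the renewal window from now."""
    return valid_until_ms < now_ms + UPDATE_ATTESTATION_TIME_MS
```

Renewing ahead of expiry keeps group membership attestations valid even if one background run is delayed or fails.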
1515 # limitations under the License.
1616
1717 import logging
18 from typing import TYPE_CHECKING, Optional
1819
1920 from synapse.api.errors import Codes, SynapseError
20 from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
21 from synapse.handlers.groups_local import GroupsLocalHandler
22 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
23 from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id
2124 from synapse.util.async_helpers import concurrently_execute
25
26 if TYPE_CHECKING:
27 from synapse.app.homeserver import HomeServer
2228
2329 logger = logging.getLogger(__name__)
2430
3137 # TODO: Flairs
3238
3339
40 # Note that the maximum lengths are somewhat arbitrary.
41 MAX_SHORT_DESC_LEN = 1000
42 MAX_LONG_DESC_LEN = 10000
43
44
3445 class GroupsServerWorkerHandler:
35 def __init__(self, hs):
46 def __init__(self, hs: "HomeServer"):
3647 self.hs = hs
3748 self.store = hs.get_datastore()
3849 self.room_list_handler = hs.get_room_list_handler()
4758 self.profile_handler = hs.get_profile_handler()
4859
4960 async def check_group_is_ours(
50 self, group_id, requester_user_id, and_exists=False, and_is_admin=None
51 ):
61 self,
62 group_id: str,
63 requester_user_id: str,
64 and_exists: bool = False,
65 and_is_admin: Optional[str] = None,
66 ) -> Optional[dict]:
5267 """Check that the group is ours, and optionally if it exists.
5368
5469 If group does exist then return group.
5570
5671 Args:
57 group_id (str)
58 and_exists (bool): whether to also check if group exists
59 and_is_admin (str): whether to also check if given str is a user_id
72 group_id: The group ID to check.
73 requester_user_id: The user ID of the requester.
74 and_exists: whether to also check if group exists
75 and_is_admin: whether to also check if given str is a user_id
6076 that is an admin
6177 """
6278 if not self.is_mine_id(group_id):
7995
8096 return group
8197
82 async def get_group_summary(self, group_id, requester_user_id):
98 async def get_group_summary(
99 self, group_id: str, requester_user_id: str
100 ) -> JsonDict:
83101 """Get the summary for a group as seen by requester_user_id.
84102
85103 The group summary consists of the profile of the room, and a curated
112130 entry = await self.room_list_handler.generate_room_entry(
113131 room_id, len(joined_users), with_alias=False, allow_private=True
114132 )
133 if entry is None:
134 continue
115135 entry = dict(entry) # so we don't change what's cached
116136 entry.pop("room_id", None)
117137
119139
120140 rooms.sort(key=lambda e: e.get("order", 0))
121141
122 for entry in users:
123 user_id = entry["user_id"]
142 for user in users:
143 user_id = user["user_id"]
124144
125145 if not self.is_mine_id(requester_user_id):
126146 attestation = await self.store.get_remote_attestation(group_id, user_id)
127147 if not attestation:
128148 continue
129149
130 entry["attestation"] = attestation
150 user["attestation"] = attestation
131151 else:
132 entry["attestation"] = self.attestations.create_attestation(
152 user["attestation"] = self.attestations.create_attestation(
133153 group_id, user_id
134154 )
135155
136156 user_profile = await self.profile_handler.get_profile_from_cache(user_id)
137 entry.update(user_profile)
157 user.update(user_profile)
138158
139159 users.sort(key=lambda e: e.get("order", 0))
140160
157177 "user": membership_info,
158178 }
159179
160 async def get_group_categories(self, group_id, requester_user_id):
161 """Get all categories in a group (as seen by user)
162 """
180 async def get_group_categories(
181 self, group_id: str, requester_user_id: str
182 ) -> JsonDict:
183 """Get all categories in a group (as seen by user)"""
163184 await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
164185
165186 categories = await self.store.get_group_categories(group_id=group_id)
166187 return {"categories": categories}
167188
168 async def get_group_category(self, group_id, requester_user_id, category_id):
169 """Get a specific category in a group (as seen by user)
170 """
189 async def get_group_category(
190 self, group_id: str, requester_user_id: str, category_id: str
191 ) -> JsonDict:
192 """Get a specific category in a group (as seen by user)"""
171193 await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
172194
173 res = await self.store.get_group_category(
195 return await self.store.get_group_category(
174196 group_id=group_id, category_id=category_id
175197 )
176198
177 logger.info("group %s", res)
178
179 return res
180
181 async def get_group_roles(self, group_id, requester_user_id):
182 """Get all roles in a group (as seen by user)
183 """
199 async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict:
200 """Get all roles in a group (as seen by user)"""
184201 await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
185202
186203 roles = await self.store.get_group_roles(group_id=group_id)
187204 return {"roles": roles}
188205
189 async def get_group_role(self, group_id, requester_user_id, role_id):
190 """Get a specific role in a group (as seen by user)
191 """
206 async def get_group_role(
207 self, group_id: str, requester_user_id: str, role_id: str
208 ) -> JsonDict:
209 """Get a specific role in a group (as seen by user)"""
192210 await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
193211
194 res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
195 return res
196
197 async def get_group_profile(self, group_id, requester_user_id):
198 """Get the group profile as seen by requester_user_id
199 """
212 return await self.store.get_group_role(group_id=group_id, role_id=role_id)
213
214 async def get_group_profile(
215 self, group_id: str, requester_user_id: str
216 ) -> JsonDict:
217 """Get the group profile as seen by requester_user_id"""
200218
201219 await self.check_group_is_ours(group_id, requester_user_id)
202220
217235 else:
218236 raise SynapseError(404, "Unknown group")
219237
220 async def get_users_in_group(self, group_id, requester_user_id):
238 async def get_users_in_group(
239 self, group_id: str, requester_user_id: str
240 ) -> JsonDict:
221241 """Get the users in group as seen by requester_user_id.
222242
223243 The ordering is arbitrary at the moment
266286
267287 return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
268288
269 async def get_invited_users_in_group(self, group_id, requester_user_id):
289 async def get_invited_users_in_group(
290 self, group_id: str, requester_user_id: str
291 ) -> JsonDict:
270292 """Get the users that have been invited to a group as seen by requester_user_id.
271293
272294 The ordering is arbitrary at the moment
296318
297319 return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
298320
299 async def get_rooms_in_group(self, group_id, requester_user_id):
321 async def get_rooms_in_group(
322 self, group_id: str, requester_user_id: str
323 ) -> JsonDict:
300324 """Get the rooms in group as seen by requester_user_id
301325
302326 This returns rooms in order of decreasing number of joined users
334358
335359
336360 class GroupsServerHandler(GroupsServerWorkerHandler):
337 def __init__(self, hs):
361 def __init__(self, hs: "HomeServer"):
338362 super().__init__(hs)
339363
340364 # Ensure attestations get renewed
341365 hs.get_groups_attestation_renewer()
342366
343367 async def update_group_summary_room(
344 self, group_id, requester_user_id, room_id, category_id, content
345 ):
346 """Add/update a room to the group summary
347 """
368 self,
369 group_id: str,
370 requester_user_id: str,
371 room_id: str,
372 category_id: str,
373 content: JsonDict,
374 ) -> JsonDict:
375 """Add/update a room to the group summary"""
348376 await self.check_group_is_ours(
349377 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
350378 )
366394 return {}
367395
368396 async def delete_group_summary_room(
369 self, group_id, requester_user_id, room_id, category_id
370 ):
371 """Remove a room from the summary
372 """
397 self, group_id: str, requester_user_id: str, room_id: str, category_id: str
398 ) -> JsonDict:
399 """Remove a room from the summary"""
373400 await self.check_group_is_ours(
374401 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
375402 )
380407
381408 return {}
382409
383 async def set_group_join_policy(self, group_id, requester_user_id, content):
410 async def set_group_join_policy(
411 self, group_id: str, requester_user_id: str, content: JsonDict
412 ) -> JsonDict:
384413 """Sets the group join policy.
385414
386415 Currently supported policies are:
400429 return {}
401430
402431 async def update_group_category(
403 self, group_id, requester_user_id, category_id, content
404 ):
405 """Add/Update a group category
406 """
432 self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict
433 ) -> JsonDict:
434 """Add/Update a group category"""
407435 await self.check_group_is_ours(
408436 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
409437 )
420448
421449 return {}
422450
423 async def delete_group_category(self, group_id, requester_user_id, category_id):
424 """Delete a group category
425 """
451 async def delete_group_category(
452 self, group_id: str, requester_user_id: str, category_id: str
453 ) -> JsonDict:
454 """Delete a group category"""
426455 await self.check_group_is_ours(
427456 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
428457 )
433462
434463 return {}
435464
436 async def update_group_role(self, group_id, requester_user_id, role_id, content):
437 """Add/update a role in a group
438 """
465 async def update_group_role(
466 self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict
467 ) -> JsonDict:
468 """Add/update a role in a group"""
439469 await self.check_group_is_ours(
440470 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
441471 )
450480
451481 return {}
452482
453 async def delete_group_role(self, group_id, requester_user_id, role_id):
454 """Remove role from group
455 """
483 async def delete_group_role(
484 self, group_id: str, requester_user_id: str, role_id: str
485 ) -> JsonDict:
486 """Remove role from group"""
456487 await self.check_group_is_ours(
457488 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
458489 )
462493 return {}
463494
464495 async def update_group_summary_user(
465 self, group_id, requester_user_id, user_id, role_id, content
466 ):
467 """Add/update a users entry in the group summary
468 """
496 self,
497 group_id: str,
498 requester_user_id: str,
499 user_id: str,
500 role_id: str,
501 content: JsonDict,
502 ) -> JsonDict:
503 """Add/update a user's entry in the group summary"""
469504 await self.check_group_is_ours(
470505 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
471506 )
485520 return {}
486521
487522 async def delete_group_summary_user(
488 self, group_id, requester_user_id, user_id, role_id
489 ):
490 """Remove a user from the group summary
491 """
523 self, group_id: str, requester_user_id: str, user_id: str, role_id: str
524 ) -> JsonDict:
525 """Remove a user from the group summary"""
492526 await self.check_group_is_ours(
493527 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
494528 )
499533
500534 return {}
501535
502 async def update_group_profile(self, group_id, requester_user_id, content):
503 """Update the group profile
504 """
536 async def update_group_profile(
537 self, group_id: str, requester_user_id: str, content: JsonDict
538 ) -> None:
539 """Update the group profile"""
505540 await self.check_group_is_ours(
506541 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
507542 )
508543
509544 profile = {}
510 for keyname in ("name", "avatar_url", "short_description", "long_description"):
545 for keyname, max_length in (
546 ("name", MAX_DISPLAYNAME_LEN),
547 ("avatar_url", MAX_AVATAR_URL_LEN),
548 ("short_description", MAX_SHORT_DESC_LEN),
549 ("long_description", MAX_LONG_DESC_LEN),
550 ):
511551 if keyname in content:
512552 value = content[keyname]
513553 if not isinstance(value, str):
514 raise SynapseError(400, "%r value is not a string" % (keyname,))
554 raise SynapseError(
555 400,
556 "%r value is not a string" % (keyname,),
557 errcode=Codes.INVALID_PARAM,
558 )
559 if len(value) > max_length:
560 raise SynapseError(
561 400,
562 "Invalid %s parameter" % (keyname,),
563 errcode=Codes.INVALID_PARAM,
564 )
515565 profile[keyname] = value
516566
517567 await self.store.update_group_profile(group_id, profile)
518568
519 async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
520 """Add room to group
521 """
569 async def add_room_to_group(
570 self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict
571 ) -> JsonDict:
572 """Add room to group"""
522573 RoomID.from_string(room_id) # Ensure valid room id
523574
524575 await self.check_group_is_ours(
532583 return {}
533584
534585 async def update_room_in_group(
535 self, group_id, requester_user_id, room_id, config_key, content
536 ):
537 """Update room in group
538 """
586 self,
587 group_id: str,
588 requester_user_id: str,
589 room_id: str,
590 config_key: str,
591 content: JsonDict,
592 ) -> JsonDict:
593 """Update room in group"""
539594 RoomID.from_string(room_id) # Ensure valid room id
540595
541596 await self.check_group_is_ours(
553608
554609 return {}
555610
556 async def remove_room_from_group(self, group_id, requester_user_id, room_id):
557 """Remove room from group
558 """
611 async def remove_room_from_group(
612 self, group_id: str, requester_user_id: str, room_id: str
613 ) -> JsonDict:
614 """Remove room from group"""
559615 await self.check_group_is_ours(
560616 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
561617 )
564620
565621 return {}
566622
567 async def invite_to_group(self, group_id, user_id, requester_user_id, content):
568 """Invite user to group
569 """
623 async def invite_to_group(
624 self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
625 ) -> JsonDict:
626 """Invite user to group"""
570627
571628 group = await self.check_group_is_ours(
572629 group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
573630 )
631 if not group:
632 raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE)
574633
575634 # TODO: Check if user knocked
576635
593652
594653 if self.hs.is_mine_id(user_id):
595654 groups_local = self.hs.get_groups_local_handler()
655 assert isinstance(
656 groups_local, GroupsLocalHandler
657 ), "Workers cannot invites users to groups."
596658 res = await groups_local.on_invite(group_id, user_id, content)
597659 local_attestation = None
598660 else:
628690 local_attestation=local_attestation,
629691 remote_attestation=remote_attestation,
630692 )
693 return {"state": "join"}
631694 elif res["state"] == "invite":
632695 await self.store.add_group_invite(group_id, user_id)
633696 return {"state": "invite"}
636699 else:
637700 raise SynapseError(502, "Unknown state returned by HS")
638701
639 async def _add_user(self, group_id, user_id, content):
702 async def _add_user(
703 self, group_id: str, user_id: str, content: JsonDict
704 ) -> Optional[JsonDict]:
640705 """Add a user to a group based on a content dict.
641706
642707 See accept_invite, join_group.
643708 """
644709 if not self.hs.is_mine_id(user_id):
645 local_attestation = self.attestations.create_attestation(group_id, user_id)
710 local_attestation = self.attestations.create_attestation(
711 group_id, user_id
712 ) # type: Optional[JsonDict]
646713
647714 remote_attestation = content["attestation"]
648715
666733
667734 return local_attestation
668735
669 async def accept_invite(self, group_id, requester_user_id, content):
736 async def accept_invite(
737 self, group_id: str, requester_user_id: str, content: JsonDict
738 ) -> JsonDict:
670739 """User tries to accept an invite to the group.
671740
672741 This is different from them asking to join, and so should error if no
685754
686755 return {"state": "join", "attestation": local_attestation}
687756
688 async def join_group(self, group_id, requester_user_id, content):
757 async def join_group(
758 self, group_id: str, requester_user_id: str, content: JsonDict
759 ) -> JsonDict:
689760 """User tries to join the group.
690761
691762 This will error if the group requires an invite/knock to join
694765 group_info = await self.check_group_is_ours(
695766 group_id, requester_user_id, and_exists=True
696767 )
768 if not group_info:
769 raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND)
697770 if group_info["join_policy"] != "open":
698771 raise SynapseError(403, "Group is not publicly joinable")
699772
701774
702775 return {"state": "join", "attestation": local_attestation}
703776
704 async def knock(self, group_id, requester_user_id, content):
705 """A user requests becoming a member of the group
706 """
707 await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
708
709 raise NotImplementedError()
710
711 async def accept_knock(self, group_id, requester_user_id, content):
712 """Accept a users knock to the room.
713
714 Errors if the user hasn't knocked, rather than inviting them.
715 """
716
717 await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
718
719 raise NotImplementedError()
720
721777 async def remove_user_from_group(
722 self, group_id, user_id, requester_user_id, content
723 ):
778 self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
779 ) -> JsonDict:
724780 """Remove a user from the group; either a user is leaving or an admin
725781 kicked them.
726782 """
742798 if is_kick:
743799 if self.hs.is_mine_id(user_id):
744800 groups_local = self.hs.get_groups_local_handler()
801 assert isinstance(
802 groups_local, GroupsLocalHandler
803 ), "Workers cannot remove users from groups."
745804 await groups_local.user_removed_from_group(group_id, user_id, {})
746805 else:
747806 await self.transport_client.remove_user_from_group_notification(
758817
759818 return {}
760819
761 async def create_group(self, group_id, requester_user_id, content):
762 group = await self.check_group_is_ours(group_id, requester_user_id)
763
820 async def create_group(
821 self, group_id: str, requester_user_id: str, content: JsonDict
822 ) -> JsonDict:
764823 logger.info("Attempting to create group with ID: %r", group_id)
765824
766825 # parsing the id into a GroupID validates it.
767826 group_id_obj = GroupID.from_string(group_id)
768827
828 group = await self.check_group_is_ours(group_id, requester_user_id)
769829 if group:
770830 raise SynapseError(400, "Group already exists")
771831
810870
811871 local_attestation = self.attestations.create_attestation(
812872 group_id, requester_user_id
813 )
873 ) # type: Optional[JsonDict]
814874 else:
815875 local_attestation = None
816876 remote_attestation = None
833893
834894 return {"group_id": group_id}
835895
836 async def delete_group(self, group_id, requester_user_id):
896 async def delete_group(self, group_id: str, requester_user_id: str) -> None:
837897 """Deletes a group, kicking out all current members.
838898
839899 Only group admins or server admins can call this request
840900
841901 Args:
842 group_id (str)
843 request_user_id (str)
844
902 group_id: The group ID to delete.
903 requester_user_id: The user requesting to delete the group.
845904 """
846905
847906 await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
864923 async def _kick_user_from_group(user_id):
865924 if self.hs.is_mine_id(user_id):
866925 groups_local = self.hs.get_groups_local_handler()
926 assert isinstance(
927 groups_local, GroupsLocalHandler
928 ), "Workers cannot kick users from groups."
867929 await groups_local.user_removed_from_group(group_id, user_id, {})
868930 else:
869931 await self.transport_client.remove_user_from_group_notification(
895957 await self.store.delete_group(group_id)
896958
897959
898 def _parse_join_policy_from_contents(content):
899 """Given a content for a request, return the specified join policy or None
900 """
960 def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]:
961 """Given a content for a request, return the specified join policy or None"""
901962
902963 join_policy_dict = content.get("m.join_policy")
903964 if join_policy_dict:
906967 return None
907968
908969
909 def _parse_join_policy_dict(join_policy_dict):
910 """Given a dict for the "m.join_policy" config return the join policy specified
911 """
970 def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str:
971 """Given a dict for the "m.join_policy" config return the join policy specified"""
912972 join_policy_type = join_policy_dict.get("type")
913973 if not join_policy_type:
914974 return "invite"
918978 return join_policy_type
919979
920980
921 def _parse_visibility_from_contents(content):
981 def _parse_visibility_from_contents(content: JsonDict) -> bool:
922982 """Given a content for a request parse out whether the entity should be
923983 public or not
924984 """
932992 return is_public
933993
934994
935 def _parse_visibility_dict(visibility):
995 def _parse_visibility_dict(visibility: JsonDict) -> bool:
936996 """Given a dict for the "m.visibility" config return if the entity should
937997 be public or not
938998 """
202202
203203
204204 class ExfiltrationWriter(metaclass=abc.ABCMeta):
205 """Interface used to specify how to write exported data.
206 """
205 """Interface used to specify how to write exported data."""
207206
208207 @abc.abstractmethod
209208 def write_events(self, room_id: str, events: List[EventBase]) -> None:
210 """Write a batch of events for a room.
211 """
209 """Write a batch of events for a room."""
212210 raise NotImplementedError()
213211
214212 @abc.abstractmethod
289289 if not interested:
290290 continue
291291 presence_events, _ = await presence_source.get_new_events(
292 user=user, service=service, from_key=from_key,
292 user=user,
293 service=service,
294 from_key=from_key,
293295 )
294296 time_now = self.clock.time_msec()
295297 events.extend(
119119 # Ensure the identifier has a type
120120 if "type" not in identifier:
121121 raise SynapseError(
122 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
122 400,
123 "'identifier' dict has no key 'type'",
124 errcode=Codes.MISSING_PARAM,
123125 )
124126
125127 return identifier
350352
351353 try:
352354 result, params, session_id = await self.check_ui_auth(
353 flows, request, request_body, description, get_new_session_data,
355 flows,
356 request,
357 request_body,
358 description,
359 get_new_session_data,
354360 )
355361 except LoginError:
356362 # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
378384 return params, session_id
379385
380386 async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
381 """Get a list of the authentication types this user can use
382 """
387 """Get a list of the authentication types this user can use"""
383388
384389 ui_auth_types = set()
385390
722727 }
723728
724729 def _auth_dict_for_flows(
725 self, flows: List[List[str]], session_id: str,
730 self,
731 flows: List[List[str]],
732 session_id: str,
726733 ) -> Dict[str, Any]:
727734 public_flows = []
728735 for f in flows:
879886 return self._supported_login_types
880887
881888 async def validate_login(
882 self, login_submission: Dict[str, Any], ratelimit: bool = False,
889 self,
890 login_submission: Dict[str, Any],
891 ratelimit: bool = False,
883892 ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
884893 """Authenticates the user for the /login API
885894
10221031 raise
10231032
10241033 async def _validate_userid_login(
1025 self, username: str, login_submission: Dict[str, Any],
1034 self,
1035 username: str,
1036 login_submission: Dict[str, Any],
10261037 ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
10271038 """Helper for validate_login
10281039
14451456 # is considered OK since the newest SSO attributes should be most valid.
14461457 if extra_attributes:
14471458 self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
1448 self._clock.time_msec(), extra_attributes,
1459 self._clock.time_msec(),
1460 extra_attributes,
14491461 )
14501462
14511463 # Create a login token
14711483 # Remove the query parameters from the redirect URL to get a shorter version of
14721484 # it. This is only to display a human-readable URL in the template, but not the
14731485 # URL we redirect users to.
1474 redirect_url_no_params = client_redirect_url.split("?")[0]
1486 url_parts = urllib.parse.urlsplit(client_redirect_url)
1487
1488 if url_parts.scheme == "https":
1489 # for an https uri, just show the netloc (ie, the hostname. Specifically,
1490 # the bit between "//" and "/"; this includes any potential
1491 # "username:password@" prefix.)
1492 display_url = url_parts.netloc
1493 else:
1494 # for other uris, strip the query-params (including the login token) and
1495 # fragment.
1496 display_url = urllib.parse.urlunsplit(
1497 (url_parts.scheme, url_parts.netloc, url_parts.path, "", "")
1498 )
14751499
14761500 html = self._sso_redirect_confirm_template.render(
1477 display_url=redirect_url_no_params,
1501 display_url=display_url,
14781502 redirect_url=redirect_url,
14791503 server_name=self._server_name,
14801504 new_user=new_user,
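The new display-URL logic above can be exercised in isolation. A hedged sketch of the same branching (not the handler itself): https URLs show only the host, while other schemes keep the path but drop the query string (which may carry a login token) and fragment.

```python
import urllib.parse

# Sketch of the display-URL derivation above: for https URLs only the
# netloc is shown; for other schemes the query-params and fragment are
# stripped before display.
def make_display_url(client_redirect_url):
    url_parts = urllib.parse.urlsplit(client_redirect_url)
    if url_parts.scheme == "https":
        # Just the netloc (the bit between "//" and "/").
        return url_parts.netloc
    # Strip query-params and fragment, keep scheme/netloc/path.
    return urllib.parse.urlunsplit(
        (url_parts.scheme, url_parts.netloc, url_parts.path, "", "")
    )
```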
16891713 # This might return an awaitable, if it does block the log out
16901714 # until it completes.
16911715 await maybe_awaitable(
1692 g(user_id=user_id, device_id=device_id, access_token=access_token,)
1693 )
1716 g(
1717 user_id=user_id,
1718 device_id=device_id,
1719 access_token=access_token,
1720 )
1721 )
1313 # limitations under the License.
1414 import logging
1515 import urllib.parse
16 from typing import TYPE_CHECKING, Dict, Optional
16 from typing import TYPE_CHECKING, Dict, List, Optional
1717 from xml.etree import ElementTree as ET
1818
1919 import attr
3232
3333
3434 class CasError(Exception):
35 """Used to catch errors when validating the CAS ticket.
36 """
35 """Used to catch errors when validating the CAS ticket."""
3736
3837 def __init__(self, error, error_description=None):
3938 self.error = error
4847 @attr.s(slots=True, frozen=True)
4948 class CasResponse:
5049 username = attr.ib(type=str)
51 attributes = attr.ib(type=Dict[str, Optional[str]])
50 attributes = attr.ib(type=Dict[str, List[Optional[str]]])
5251
5352
5453 class CasHandler:
9998 Returns:
10099 The URL to use as a "service" parameter.
101100 """
102 return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
101 return "%s?%s" % (
102 self._cas_service_url,
103 urllib.parse.urlencode(args),
104 )
103105
104106 async def _validate_ticket(
105107 self, ticket: str, service_args: Dict[str, str]
168170
169171 # Iterate through the nodes and pull out the user and any extra attributes.
170172 user = None
171 attributes = {}
173 attributes = {} # type: Dict[str, List[Optional[str]]]
172174 for child in root[0]:
173175 if child.tag.endswith("user"):
174176 user = child.text
181183 tag = attribute.tag
182184 if "}" in tag:
183185 tag = tag.split("}")[1]
184 attributes[tag] = attribute.text
186 attributes.setdefault(tag, []).append(attribute.text)
185187
186188 # Ensure a user was found.
187189 if user is None:
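The attribute-handling change above switches from last-wins assignment to list accumulation, because a CAS response may repeat an attribute tag. A minimal sketch of the `setdefault(...).append(...)` pattern (the sample tags below are illustrative, not real CAS output):

```python
# Sketch of the multi-valued attribute collection above: each repeated tag
# maps to a list of all its values rather than only the last one seen.
attributes = {}
for tag, text in [("role", "staff"), ("role", "admin"), ("email", "a@example.org")]:
    attributes.setdefault(tag, []).append(text)
```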
295297 # first check if we're doing a UIA
296298 if session:
297299 return await self._sso_handler.complete_sso_ui_auth_request(
298 self.idp_id, cas_response.username, session, request,
300 self.idp_id,
301 cas_response.username,
302 session,
303 request,
299304 )
300305
301306 # otherwise, we're handling a login request.
302307
303308 # Ensure that the attributes of the logged in user meet the required
304309 # attributes.
305 for required_attribute, required_value in self._cas_required_attributes.items():
306 # If required attribute was not in CAS Response - Forbidden
307 if required_attribute not in cas_response.attributes:
308 self._sso_handler.render_error(
309 request,
310 "unauthorised",
311 "You are not authorised to log in here.",
312 401,
313 )
314 return
315
316 # Also need to check value
317 if required_value is not None:
318 actual_value = cas_response.attributes[required_attribute]
319 # If required attribute value does not match expected - Forbidden
320 if required_value != actual_value:
321 self._sso_handler.render_error(
322 request,
323 "unauthorised",
324 "You are not authorised to log in here.",
325 401,
326 )
327 return
310 if not self._sso_handler.check_required_attributes(
311 request, cas_response.attributes, self._cas_required_attributes
312 ):
313 return
328314
329315 # Call the mapper to register/login the user
330316
371357 if failures:
372358 raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
373359
360 # Arbitrarily use the first attribute found.
374361 display_name = cas_response.attributes.get(
375 self._cas_displayname_attribute, None
376 )
362 self._cas_displayname_attribute, [None]
363 )[0]
377364
378365 return UserAttributes(localpart=localpart, display_name=display_name)
379366
383370 user_id = UserID(localpart, self._hostname).to_string()
384371
385372 logger.debug(
386 "Looking for existing account based on mapped %s", user_id,
373 "Looking for existing account based on mapped %s",
374 user_id,
387375 )
388376
389377 users = await self._store.get_users_by_id_case_insensitive(user_id)
195195 run_as_background_process("user_parter_loop", self._user_parter_loop)
196196
197197 async def _user_parter_loop(self) -> None:
198 """Loop that parts deactivated users from rooms
199 """
198 """Loop that parts deactivated users from rooms"""
200199 self._user_parter_running = True
201200 logger.info("Starting user parter")
202201 try:
213212 self._user_parter_running = False
214213
215214 async def _part_user(self, user_id: str) -> None:
216 """Causes the given user_id to leave all the rooms they're joined to
217 """
215 """Causes the given user_id to leave all the rooms they're joined to"""
218216 user = UserID.from_string(user_id)
219217
220218 rooms_for_user = await self.store.get_rooms_for_user(user_id)
8585
8686 @trace
8787 async def get_device(self, user_id: str, device_id: str) -> JsonDict:
88 """ Retrieve the given device
88 """Retrieve the given device
8989
9090 Args:
9191 user_id: The user to get the device from
340340
341341 @trace
342342 async def delete_device(self, user_id: str, device_id: str) -> None:
343 """ Delete the given device
343 """Delete the given device
344344
345345 Args:
346346 user_id: The user to delete the device from.
385385 await self.delete_devices(user_id, device_ids)
386386
387387 async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
388 """ Delete several devices
388 """Delete several devices
389389
390390 Args:
391391 user_id: The user to delete devices from.
416416 await self.notify_device_update(user_id, device_ids)
417417
418418 async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
419 """ Update the given device
419 """Update the given device
420420
421421 Args:
422422 user_id: The user to update devices of.
533533 device id of the dehydrated device
534534 """
535535 device_id = await self.check_device_registered(
536 user_id, None, initial_device_display_name,
536 user_id,
537 None,
538 initial_device_display_name,
537539 )
538540 old_device_id = await self.store.store_dehydrated_device(
539541 user_id, device_id, device_data
802804 try:
803805 # Try to resync the current user's devices list.
804806 result = await self.user_device_resync(
805 user_id=user_id, mark_failed_as_stale=False,
807 user_id=user_id,
808 mark_failed_as_stale=False,
806809 )
807810
808811 # user_device_resync only returns a result if it managed to
812815 # self.store.update_remote_device_list_cache).
813816 if result:
814817 logger.debug(
815 "Successfully resynced the device list for %s", user_id,
818 "Successfully resynced the device list for %s",
819 user_id,
816820 )
817821 except Exception as e:
818822 # If there was an issue resyncing this user, e.g. if the remote
819823 # server sent a malformed result, just log the error instead of
820824 # aborting all the subsequent resyncs.
821825 logger.debug(
822 "Could not resync the device list for %s: %s", user_id, e,
826 "Could not resync the device list for %s: %s",
827 user_id,
828 e,
823829 )
824830 finally:
825831 # Allow future calls to retry resyncing out-of-sync device lists.
854860 return None
855861 except (RequestSendFailed, HttpResponseException) as e:
856862 logger.warning(
857 "Failed to handle device list update for %s: %s", user_id, e,
863 "Failed to handle device list update for %s: %s",
864 user_id,
865 e,
858866 )
859867
860868 if mark_failed_as_stale:
930938
931939 # Handle cross-signing keys.
932940 cross_signing_device_ids = await self.process_cross_signing_key_update(
933 user_id, master_key, self_signing_key,
941 user_id,
942 master_key,
943 self_signing_key,
934944 )
935945 device_ids = device_ids + cross_signing_device_ids
936946
6161 )
6262 else:
6363 hs.get_federation_registry().register_instances_for_edu(
64 "m.direct_to_device", hs.config.worker.writers.to_device,
64 "m.direct_to_device",
65 hs.config.worker.writers.to_device,
6566 )
6667
6768 # The handler to call when we think a user's device list might be out of
7273 hs.get_device_handler().device_list_updater.user_device_resync
7374 )
7475 else:
75 self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
76 hs
76 self._user_device_resync = (
77 ReplicationUserDevicesResyncRestServlet.make_client(hs)
7778 )
7879
7980 async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
6060
6161 self._is_master = hs.config.worker_app is None
6262 if not self._is_master:
63 self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
64 hs
63 self._user_device_resync_client = (
64 ReplicationUserDevicesResyncRestServlet.make_client(hs)
6565 )
6666 else:
6767 # Only register this edu handler on master as it requires writing
8484 async def query_devices(
8585 self, query_body: JsonDict, timeout: int, from_user_id: str
8686 ) -> JsonDict:
87 """ Handle a device key query from a client
87 """Handle a device key query from a client
8888
8989 {
9090 "device_keys": {
390390 async def on_federation_query_client_keys(
391391 self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
392392 ) -> JsonDict:
393 """ Handle a device key query from a federated server
394 """
393 """Handle a device key query from a federated server"""
395394 device_keys_query = query_body.get(
396395 "device_keys", {}
397396 ) # type: Dict[str, Optional[List[str]]]
10641063 return key, key_id, verify_key
10651064
10661065 async def _retrieve_cross_signing_keys_for_remote_user(
1067 self, user: UserID, desired_key_type: str,
1066 self,
1067 user: UserID,
1068 desired_key_type: str,
10681069 ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
10691070 """Queries cross-signing keys for a remote user and saves them to the database
10701071
12681269
12691270 @attr.s(slots=True)
12701271 class SignatureListItem:
1271 """An item in the signature list as used by upload_signatures_for_device_keys.
1272 """
1272 """An item in the signature list as used by upload_signatures_for_device_keys."""
12731273
12741274 signing_key_id = attr.ib(type=str)
12751275 target_user_id = attr.ib(type=str)
13541354 logger.info("pending updates: %r", pending_updates)
13551355
13561356 for master_key, self_signing_key in pending_updates:
1357 new_device_ids = await device_list_updater.process_cross_signing_key_update(
1358 user_id, master_key, self_signing_key,
1357 new_device_ids = (
1358 await device_list_updater.process_cross_signing_key_update(
1359 user_id,
1360 master_key,
1361 self_signing_key,
1362 )
13591363 )
13601364 device_ids = device_ids + new_device_ids
13611365
5656 room_id: Optional[str] = None,
5757 is_guest: bool = False,
5858 ) -> JsonDict:
59 """Fetches the events stream for a given user.
60 """
59 """Fetches the events stream for a given user."""
6160
6261 if room_id:
6362 blocked = await self.store.is_room_blocked(room_id)
110110
111111 class FederationHandler(BaseHandler):
112112 """Handles events that originated from federation.
113 Responsible for:
114 a) handling received Pdus before handing them on as Events to the rest
115 of the homeserver (including auth and state conflict resolutions)
116 b) converting events that were produced by local clients that may need
117 to be sent to remote homeservers.
118 c) doing the necessary dances to invite remote users and join remote
119 rooms.
113 Responsible for:
114 a) handling received Pdus before handing them on as Events to the rest
115 of the homeserver (including auth and state conflict resolutions)
116 b) converting events that were produced by local clients that may need
117 to be sent to remote homeservers.
118 c) doing the necessary dances to invite remote users and join remote
119 rooms.
120120 """
121121
122122 def __init__(self, hs: "HomeServer"):
149149 )
150150
151151 if hs.config.worker_app:
152 self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
153 hs
154 )
155 self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
156 hs
152 self._user_device_resync = (
153 ReplicationUserDevicesResyncRestServlet.make_client(hs)
154 )
155 self._maybe_store_room_on_outlier_membership = (
156 ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
157157 )
158158 else:
159159 self._device_list_updater = hs.get_device_handler().device_list_updater
171171 self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
172172
173173 async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
174 """ Process a PDU received via a federation /send/ transaction, or
174 """Process a PDU received via a federation /send/ transaction, or
175175 via backfill of missing prev_events
176176
177177 Args:
367367 # know about
368368 for p in prevs - seen:
369369 logger.info(
370 "Requesting state at missing prev_event %s", event_id,
370 "Requesting state at missing prev_event %s",
371 event_id,
371372 )
372373
373374 with nested_logging_context(p):
387388 event_map[x.event_id] = x
388389
389390 room_version = await self.store.get_room_version_id(room_id)
390 state_map = await self._state_resolution_handler.resolve_events_with_store(
391 room_id,
392 room_version,
393 state_maps,
394 event_map,
395 state_res_store=StateResolutionStore(self.store),
391 state_map = (
392 await self._state_resolution_handler.resolve_events_with_store(
393 room_id,
394 room_version,
395 state_maps,
396 event_map,
397 state_res_store=StateResolutionStore(self.store),
398 )
396399 )
397400
398401 # We need to give _process_received_pdu the actual state events
686689 return fetched_events
687690
688691 async def _process_received_pdu(
689 self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
692 self,
693 origin: str,
694 event: EventBase,
695 state: Optional[Iterable[EventBase]],
690696 ):
691 """ Called when we have a new pdu. We need to do auth checks and put it
697 """Called when we have a new pdu. We need to do auth checks and put it
692698 through the StateHandler.
693699
694700 Args:
800806
801807 @log_function
802808 async def backfill(self, dest, room_id, limit, extremities):
803 """ Trigger a backfill request to `dest` for the given `room_id`
809 """Trigger a backfill request to `dest` for the given `room_id`
804810
805811 This will attempt to get more events from the remote. If the other side
806812 has no new events to offer, this will return an empty list.
12031209 with nested_logging_context(event_id):
12041210 try:
12051211 event = await self.federation_client.get_pdu(
1206 [destination], event_id, room_version, outlier=True,
1212 [destination],
1213 event_id,
1214 room_version,
1215 outlier=True,
12071216 )
12081217 if event is None:
12091218 logger.warning(
1210 "Server %s didn't return event %s", destination, event_id,
1219 "Server %s didn't return event %s",
1220 destination,
1221 event_id,
12111222 )
12121223 return
12131224
12341245 if aid not in event_map
12351246 ]
12361247 persisted_events = await self.store.get_events(
1237 auth_events, allow_rejected=True,
1248 auth_events,
1249 allow_rejected=True,
12381250 )
12391251
12401252 event_infos = []
12501262 event_infos.append(_NewEventInfo(event, None, auth))
12511263
12521264 await self._handle_new_events(
1253 destination, room_id, event_infos,
1265 destination,
1266 room_id,
1267 event_infos,
12541268 )
12551269
12561270 def _sanity_check_event(self, ev):
12861300 raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
12871301
12881302 async def send_invite(self, target_host, event):
1289 """ Sends the invite to the remote server for signing.
1303 """Sends the invite to the remote server for signing.
12901304
12911305 Invites must be signed by the invitee's server before distribution.
12921306 """
13091323 async def do_invite_join(
13101324 self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
13111325 ) -> Tuple[str, int]:
1312 """ Attempts to join the `joinee` to the room `room_id` via the
1326 """Attempts to join the `joinee` to the room `room_id` via the
13131327 servers contained in `target_hosts`.
13141328
13151329 This first triggers a /make_join/ request that returns a partial
13521366 self.room_queues[room_id] = []
13531367
13541368 await self._clean_room_for_join(room_id)
1355
1356 handled_events = set()
13571369
13581370 try:
13591371 # Try the host we successfully got a response to /make_join/
13731385 state = ret["state"]
13741386 auth_chain = ret["auth_chain"]
13751387 auth_chain.sort(key=lambda e: e.depth)
1376
1377 handled_events.update([s.event_id for s in state])
1378 handled_events.update([a.event_id for a in auth_chain])
1379 handled_events.add(event.event_id)
13801388
13811389 logger.debug("do_invite_join auth_chain: %s", auth_chain)
13821390 logger.debug("do_invite_join state: %s", state)
13931401 # so we can rely on it now.
13941402 #
13951403 await self.store.upsert_room_on_join(
1396 room_id=room_id, room_version=room_version_obj,
1404 room_id=room_id,
1405 room_version=room_version_obj,
13971406 )
13981407
13991408 max_stream_id = await self._persist_auth_tree(
14631472 async def on_make_join_request(
14641473 self, origin: str, room_id: str, user_id: str
14651474 ) -> EventBase:
1466 """ We've received a /make_join/ request, so we create a partial
1475 """We've received a /make_join/ request, so we create a partial
14671476 join event for the room and return that. We do *not* persist or
14681477 process it until the other server has signed it and sent it back.
14691478
14881497 is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
14891498 if not is_in_room:
14901499 logger.info(
1491 "Got /make_join request for room %s we are no longer in", room_id,
1500 "Got /make_join request for room %s we are no longer in",
1501 room_id,
14921502 )
14931503 raise NotFoundError("Not an active room on this server")
14941504
15221532 return event
15231533
15241534 async def on_send_join_request(self, origin, pdu):
1525 """ We have received a join event for a room. Fully process it and
1535 """We have received a join event for a room. Fully process it and
15261536 respond with the current state and auth chains.
15271537 """
15281538 event = pdu
15781588 async def on_invite_request(
15791589 self, origin: str, event: EventBase, room_version: RoomVersion
15801590 ):
1581 """ We've got an invite event. Process and persist it. Sign it.
1591 """We've got an invite event. Process and persist it. Sign it.
15821592
15831593 Respond with the now signed event.
15841594 """
17051715 async def on_make_leave_request(
17061716 self, origin: str, room_id: str, user_id: str
17071717 ) -> EventBase:
1708 """ We've received a /make_leave/ request, so we create a partial
1718 """We've received a /make_leave/ request, so we create a partial
17091719 leave event for the room and return that. We do *not* persist or
17101720 process it until the other server has signed it and sent it back.
17111721
17811791 return None
17821792
17831793 async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
1784 """Returns the state at the event. i.e. not including said event.
1785 """
1794 """Returns the state at the event. i.e. not including said event."""
17861795
17871796 event = await self.store.get_event(event_id, check_room_id=room_id)
17881797
18081817 return []
18091818
18101819 async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
1811 """Returns the state at the event. i.e. not including said event.
1812 """
1820 """Returns the state at the event. i.e. not including said event."""
18131821 event = await self.store.get_event(event_id, check_room_id=room_id)
18141822
18151823 state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
20152023
20162024 for e_id in missing_auth_events:
20172025 m_ev = await self.federation_client.get_pdu(
2018 [origin], e_id, room_version=room_version, outlier=True, timeout=10000,
2026 [origin],
2027 e_id,
2028 room_version=room_version,
2029 outlier=True,
2030 timeout=10000,
20192031 )
20202032 if m_ev and m_ev.event_id == e_id:
20212033 event_map[e_id] = m_ev
21652177 )
21662178
21672179 logger.debug(
2168 "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
2180 "Doing soft-fail check for %s: state %s",
2181 event.event_id,
2182 current_state_ids,
21692183 )
21702184
21712185 # Now check if event pass auth against said current state
25182532 async def construct_auth_difference(
25192533 self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
25202534 ) -> Dict:
2521 """ Given a local and remote auth chain, find the differences. This
2535 """Given a local and remote auth chain, find the differences. This
25222536 assumes that we have already processed all events in remote_auth
25232537
25242538 Params:
145145 async def get_users_in_group(
146146 self, group_id: str, requester_user_id: str
147147 ) -> JsonDict:
148 """Get users in a group
149 """
148 """Get users in a group"""
150149 if self.is_mine_id(group_id):
151150 return await self.groups_server_handler.get_users_in_group(
152151 group_id, requester_user_id
282281 async def create_group(
283282 self, group_id: str, user_id: str, content: JsonDict
284283 ) -> JsonDict:
285 """Create a group
286 """
284 """Create a group"""
287285
288286 logger.info("Asking to create group with ID: %r", group_id)
289287
313311 async def join_group(
314312 self, group_id: str, user_id: str, content: JsonDict
315313 ) -> JsonDict:
316 """Request to join a group
317 """
314 """Request to join a group"""
318315 if self.is_mine_id(group_id):
319316 await self.groups_server_handler.join_group(group_id, user_id, content)
320317 local_attestation = None
360357 async def accept_invite(
361358 self, group_id: str, user_id: str, content: JsonDict
362359 ) -> JsonDict:
363 """Accept an invite to a group
364 """
360 """Accept an invite to a group"""
365361 if self.is_mine_id(group_id):
366362 await self.groups_server_handler.accept_invite(group_id, user_id, content)
367363 local_attestation = None
407403 async def invite(
408404 self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
409405 ) -> JsonDict:
410 """Invite a user to a group
411 """
406 """Invite a user to a group"""
412407 content = {"requester_user_id": requester_user_id, "config": config}
413408 if self.is_mine_id(group_id):
414409 res = await self.groups_server_handler.invite_to_group(
433428 async def on_invite(
434429 self, group_id: str, user_id: str, content: JsonDict
435430 ) -> JsonDict:
436 """One of our users were invited to a group
437 """
431 """One of our users were invited to a group"""
438432 # TODO: Support auto join and rejection
439433
440434 if not self.is_mine_id(user_id):
465459 async def remove_user_from_group(
466460 self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
467461 ) -> JsonDict:
468 """Remove a user from a group
469 """
462 """Remove a user from a group"""
470463 if user_id == requester_user_id:
471464 token = await self.store.register_user_group_membership(
472465 group_id, user_id, membership="leave"
500493 async def user_removed_from_group(
501494 self, group_id: str, user_id: str, content: JsonDict
502495 ) -> None:
503 """One of our users was removed/kicked from a group
504 """
496 """One of our users was removed/kicked from a group"""
505497 # TODO: Check if user in group
506498 token = await self.store.register_user_group_membership(
507499 group_id, user_id, membership="leave"
7171 )
7272
7373 def ratelimit_request_token_requests(
74 self, request: SynapseRequest, medium: str, address: str,
74 self,
75 request: SynapseRequest,
76 medium: str,
77 address: str,
7578 ):
7679 """Used to ratelimit requests to `/requestToken` by IP and address.
7780
123123
124124 joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
125125 receipt = await self.store.get_linearized_receipts_for_rooms(
126 joined_rooms, to_key=int(now_token.receipt_key),
126 joined_rooms,
127 to_key=int(now_token.receipt_key),
127128 )
128129
129130 tags_by_room = await self.store.get_tags_for_user(user_id)
168169 self.state_handler.get_current_state, event.room_id
169170 )
170171 elif event.membership == Membership.LEAVE:
171 room_end_token = RoomStreamToken(None, event.stream_ordering,)
172 room_end_token = RoomStreamToken(
173 None,
174 event.stream_ordering,
175 )
172176 deferred_room_state = run_in_background(
173177 self.state_store.get_state_for_events, [event.event_id]
174178 )
283287 membership,
284288 member_event_id,
285289 ) = await self.auth.check_user_in_room_or_world_readable(
286 room_id, user_id, allow_departed_users=True,
290 room_id,
291 user_id,
292 allow_departed_users=True,
287293 )
288294 is_peeking = member_event_id is None
289295
6464
6565
6666 class MessageHandler:
67 """Contains some read only APIs to get state about a room
68 """
67 """Contains some read only APIs to get state about a room"""
6968
7069 def __init__(self, hs):
7170 self.auth = hs.get_auth()
8786 )
8887
8988 async def get_room_data(
90 self, user_id: str, room_id: str, event_type: str, state_key: str,
89 self,
90 user_id: str,
91 room_id: str,
92 event_type: str,
93 state_key: str,
9194 ) -> dict:
92 """ Get data from a room.
95 """Get data from a room.
9396
9497 Args:
9598 user_id
173176 raise NotFoundError("Can't find event for token %s" % (at_token,))
174177
175178 visible_events = await filter_events_for_client(
176 self.storage, user_id, last_events, filter_send_to_client=False,
179 self.storage,
180 user_id,
181 last_events,
182 filter_send_to_client=False,
177183 )
178184
179185 event = last_events[0]
570576 async def _is_exempt_from_privacy_policy(
571577 self, builder: EventBuilder, requester: Requester
572578 ) -> bool:
573 """"Determine if an event to be sent is exempt from having to consent
579 """ "Determine if an event to be sent is exempt from having to consent
574580 to the privacy policy
575581
576582 Args:
792798 """
793799
794800 if prev_event_ids is not None:
795 assert len(prev_event_ids) <= 10, (
796 "Attempting to create an event with %i prev_events"
797 % (len(prev_event_ids),)
801 assert (
802 len(prev_event_ids) <= 10
803 ), "Attempting to create an event with %i prev_events" % (
804 len(prev_event_ids),
798805 )
799806 else:
800807 prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
820827 )
821828 if not third_party_result:
822829 logger.info(
823 "Event %s forbidden by third-party rules", event,
830 "Event %s forbidden by third-party rules",
831 event,
824832 )
825833 raise SynapseError(
826834 403, "This event is not allowed in this context", Codes.FORBIDDEN
13151323 # Since this is a dummy-event it is OK if it is sent by a
13161324 # shadow-banned user.
13171325 await self.handle_new_client_event(
1318 requester, event, context, ratelimit=False, ignore_shadow_ban=True,
1326 requester,
1327 event,
1328 context,
1329 ratelimit=False,
1330 ignore_shadow_ban=True,
13191331 )
13201332 return True
13211333 except AuthError:
4040 from synapse.logging.context import make_deferred_yieldable
4141 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
4242 from synapse.util import json_decoder
43 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
4344
4445 if TYPE_CHECKING:
4546 from synapse.server import HomeServer
4647
4748 logger = logging.getLogger(__name__)
4849
49 SESSION_COOKIE_NAME = b"oidc_session"
50 # we want the cookie to be returned to us even when the request is the POSTed
51 # result of a form on another domain, as is used with `response_mode=form_post`.
52 #
53 # Modern browsers will not do so unless we set SameSite=None; however *older*
54 # browsers (including all versions of Safari on iOS 12?) don't support
55 # SameSite=None, and interpret it as SameSite=Strict:
56 # https://bugs.webkit.org/show_bug.cgi?id=198181
57 #
58 # As a rather painful workaround, we set *two* cookies, one with SameSite=None
59 # and one with no SameSite, in the hope that at least one of them will get
60 # back to us.
61 #
62 # Secure is necessary for SameSite=None (and, empirically, also breaks things
63 # on iOS 12.)
64 #
65 # Here we have the names of the cookies, and the options we use to set them.
66 _SESSION_COOKIES = [
67 (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
68 (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
69 ]
5070
5171 #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
5272 #: OpenID.Core sec 3.1.3.3.
7191
7292
7393 class OidcHandler:
74 """Handles requests related to the OpenID Connect login flow.
75 """
94 """Handles requests related to the OpenID Connect login flow."""
7695
7796 def __init__(self, hs: "HomeServer"):
7897 self._sso_handler = hs.get_sso_handler()
122141 Args:
123142 request: the incoming request from the browser.
124143 """
125
126144 # The provider might redirect with an error.
127145 # In that case, just display it as-is.
128146 if b"error" in request.args:
136154 # either the provider misbehaving or Synapse being misconfigured.
137155 # The only exception of that is "access_denied", where the user
138156 # probably cancelled the login flow. In other cases, log those errors.
139 if error != "access_denied":
140 logger.error("Error from the OIDC provider: %s %s", error, description)
157 logger.log(
158 logging.INFO if error == "access_denied" else logging.ERROR,
159 "Received OIDC callback with error: %s %s",
160 error,
161 description,
162 )
141163
142164 self._sso_handler.render_error(request, error, description)
143165 return
145167 # otherwise, it is presumably a successful response. see:
146168 # https://tools.ietf.org/html/rfc6749#section-4.1.2
147169
148 # Fetch the session cookie
149 session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
150 if session is None:
151 logger.info("No session cookie found")
170 # Fetch the session cookie. See the comments on SESSION_COOKIES for why there
171 # are two.
172
173 for cookie_name, _ in _SESSION_COOKIES:
174 session = request.getCookie(cookie_name) # type: Optional[bytes]
175 if session is not None:
176 break
177 else:
178 logger.info("Received OIDC callback, with no session cookie")
152179 self._sso_handler.render_error(
153180 request, "missing_session", "No session cookie found"
154181 )
155182 return
156183
157 # Remove the cookie. There is a good chance that if the callback failed
184 # Remove the cookies. There is a good chance that if the callback failed
158185 # once, it will fail next time and the code will already be exchanged.
159 # Removing it early avoids spamming the provider with token requests.
160 request.addCookie(
161 SESSION_COOKIE_NAME,
162 b"",
163 path="/_synapse/oidc",
164 expires="Thu, Jan 01 1970 00:00:00 UTC",
165 httpOnly=True,
166 sameSite="lax",
167 )
186 # Removing the cookies early avoids spamming the provider with token requests.
187 #
188 # we have to build the header by hand rather than calling request.addCookie
189 # because the latter does not support SameSite=None
190 # (https://twistedmatrix.com/trac/ticket/10088)
191
192 for cookie_name, options in _SESSION_COOKIES:
193 request.cookies.append(
194 b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s"
195 % (cookie_name, options)
196 )
168197
169198 # Check for the state query parameter
170199 if b"state" not in request.args:
171 logger.info("State parameter is missing")
200 logger.info("Received OIDC callback, with no state parameter")
172201 self._sso_handler.render_error(
173202 request, "invalid_request", "State parameter is missing"
174203 )
182211 session, state
183212 )
184213 except (MacaroonDeserializationException, ValueError) as e:
185 logger.exception("Invalid session")
214 logger.exception("Invalid session for OIDC callback")
186215 self._sso_handler.render_error(request, "invalid_session", str(e))
187216 return
188217 except MacaroonInvalidSignatureException as e:
189 logger.exception("Could not verify session")
218 logger.exception("Could not verify session for OIDC callback")
190219 self._sso_handler.render_error(request, "mismatching_session", str(e))
191220 return
221
222 logger.info("Received OIDC callback for IdP %s", session_data.idp_id)
192223
193224 oidc_provider = self._providers.get(session_data.idp_id)
194225 if not oidc_provider:
209240
210241
211242 class OidcError(Exception):
212 """Used to catch errors when calling the token_endpoint
213 """
243 """Used to catch errors when calling the token_endpoint"""
214244
215245 def __init__(self, error, error_description=None):
216246 self.error = error
239269
240270 self._token_generator = token_generator
241271
272 self._config = provider
242273 self._callback_url = hs.config.oidc_callback_url # type: str
243274
244275 self._scopes = provider.scopes
245276 self._user_profile_method = provider.user_profile_method
246277 self._client_auth = ClientAuth(
247 provider.client_id, provider.client_secret, provider.client_auth_method,
278 provider.client_id,
279 provider.client_secret,
280 provider.client_auth_method,
248281 ) # type: ClientAuth
249282 self._client_auth_method = provider.client_auth_method
250 self._provider_metadata = OpenIDProviderMetadata(
251 issuer=provider.issuer,
252 authorization_endpoint=provider.authorization_endpoint,
253 token_endpoint=provider.token_endpoint,
254 userinfo_endpoint=provider.userinfo_endpoint,
255 jwks_uri=provider.jwks_uri,
256 ) # type: OpenIDProviderMetadata
257 self._provider_needs_discovery = provider.discover
283
284 # cache of metadata for the identity provider (endpoint uris, mostly). This is
285 # loaded on-demand from the discovery endpoint (if discovery is enabled), with
286 # possible overrides from the config. Access via `load_metadata`.
287 self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
288
289 # cache of JWKs used by the identity provider to sign tokens. Loaded on demand
290 # from the IdP's jwks_uri, if required.
291 self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
292
258293 self._user_mapping_provider = provider.user_mapping_provider_class(
259294 provider.user_mapping_provider_config
260295 )
280315
281316 self._sso_handler.register_identity_provider(self)
282317
283 def _validate_metadata(self):
318 def _validate_metadata(self, m: OpenIDProviderMetadata) -> None:
284319 """Verifies the provider metadata.
285320
286321 This checks the validity of the currently loaded provider. Not
299334 if self._skip_verification is True:
300335 return
301336
302 m = self._provider_metadata
303337 m.validate_issuer()
304338 m.validate_authorization_endpoint()
305339 m.validate_token_endpoint()
334368 )
335369 else:
336370 # If we're not using userinfo, we need a valid jwks to validate the ID token
337 if m.get("jwks") is None:
338 if m.get("jwks_uri") is not None:
339 m.validate_jwks_uri()
340 else:
341 raise ValueError('"jwks_uri" must be set')
371 m.validate_jwks_uri()
342372
343373 @property
344374 def _uses_userinfo(self) -> bool:
355385 or self._user_profile_method == "userinfo_endpoint"
356386 )
357387
358 async def load_metadata(self) -> OpenIDProviderMetadata:
359 """Load and validate the provider metadata.
360
361 The values metadatas are discovered if ``oidc_config.discovery`` is
362 ``True`` and then cached.
388 async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
389 """Return the provider metadata.
390
391 If this is the first call, the metadata is built from the config and from the
392 metadata discovery endpoint (if enabled), and then validated. If the metadata
393 is successfully validated, it is then cached for future use.
394
395 Args:
396 force: If true, any cached metadata is discarded to force a reload.
363397
364398 Raises:
365399 ValueError: if something in the provider is not valid
367401 Returns:
368402 The provider's metadata.
369403 """
370 # If we are using the OpenID Discovery documents, it needs to be loaded once
371 # FIXME: should there be a lock here?
372 if self._provider_needs_discovery:
373 url = get_well_known_url(self._provider_metadata["issuer"], external=True)
404 if force:
405 # reset the cached call to ensure we get a new result
406 self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)
407
408 return await self._provider_metadata.get()
409
410 async def _load_metadata(self) -> OpenIDProviderMetadata:
411 # start out with just the issuer (unlike the other settings, discovered issuer
412 # takes precedence over configured issuer, because configured issuer is
413 # required for discovery to take place.)
414 #
415 metadata = OpenIDProviderMetadata(issuer=self._config.issuer)
416
417 # load any data from the discovery endpoint, if enabled
418 if self._config.discover:
419 url = get_well_known_url(self._config.issuer, external=True)
374420 metadata_response = await self._http_client.get_json(url)
375 # TODO: maybe update the other way around to let user override some values?
376 self._provider_metadata.update(metadata_response)
377 self._provider_needs_discovery = False
378
379 self._validate_metadata()
380
381 return self._provider_metadata
421 metadata.update(metadata_response)
422
423 # override any discovered data with any settings in our config
424 if self._config.authorization_endpoint:
425 metadata["authorization_endpoint"] = self._config.authorization_endpoint
426
427 if self._config.token_endpoint:
428 metadata["token_endpoint"] = self._config.token_endpoint
429
430 if self._config.userinfo_endpoint:
431 metadata["userinfo_endpoint"] = self._config.userinfo_endpoint
432
433 if self._config.jwks_uri:
434 metadata["jwks_uri"] = self._config.jwks_uri
435
436 self._validate_metadata(metadata)
437
438 return metadata
382439
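The caching behaviour that `load_metadata` relies on above can be illustrated with a hypothetical, minimal version of `RetryOnExceptionCachedCall` (the real Synapse class also coalesces concurrent callers, which this sketch omits): a failed call caches nothing, so the next `get()` retries, while a successful result is cached indefinitely. `load_metadata(force=True)` then discards the cache simply by replacing the instance.

```python
import asyncio
from typing import Awaitable, Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class RetryOnExceptionCachedCall(Generic[T]):
    """Minimal sketch: run a coroutine at most once *successfully*.

    An exception is not cached, so the next get() retries. The real
    implementation additionally deduplicates concurrent callers.
    """

    def __init__(self, f: Callable[[], Awaitable[T]]):
        self._callable = f
        self._result: Optional[T] = None
        self._has_result = False

    async def get(self) -> T:
        if not self._has_result:
            # If this raises, nothing is cached and the next get() retries.
            self._result = await self._callable()
            self._has_result = True
        return self._result
```

Usage mirrors the handler above: `self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata)` at construction time, then `await self._provider_metadata.get()` whenever the metadata is needed.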
383440 async def load_jwks(self, force: bool = False) -> JWKS:
384441 """Load the JSON Web Key Set used to sign ID tokens.
408465 ]
409466 }
410467 """
468 if force:
469 # reset the cached call to ensure we get a new result
470 self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
471 return await self._jwks.get()
472
473 async def _load_jwks(self) -> JWKS:
411474 if self._uses_userinfo:
412475 # We're not using jwt signing, return an empty jwk set
413476 return {"keys": []}
414477
415 # First check if the JWKS are loaded in the provider metadata.
416 # It can happen either if the provider gives its JWKS in the discovery
417 # document directly or if it was already loaded once.
418478 metadata = await self.load_metadata()
419 jwk_set = metadata.get("jwks")
420 if jwk_set is not None and not force:
421 return jwk_set
422
423 # Loading the JWKS using the `jwks_uri` metadata
479
480 # Load the JWKS using the `jwks_uri` metadata.
424481 uri = metadata.get("jwks_uri")
425482 if not uri:
483 # this should be unreachable: load_metadata validates that
484 # there is a jwks_uri in the metadata if _uses_userinfo is false
426485 raise RuntimeError('Missing "jwks_uri" in metadata')
427486
428487 jwk_set = await self._http_client.get_json(uri)
429488
430 # Caching the JWKS in the provider's metadata
431 self._provider_metadata["jwks"] = jwk_set
432489 return jwk_set
433490
434491 async def _exchange_code(self, code: str) -> Token:
486543 # We're not using the SimpleHttpClient util methods as we don't want to
487544 # check the HTTP status code and we do the body encoding ourself.
488545 response = await self._http_client.request(
489 method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
546 method="POST",
547 uri=uri,
548 data=body.encode("utf-8"),
549 headers=headers,
490550 )
491551
492552 # This is used in multiple error messages below
564624 Returns:
565625 UserInfo: an object representing the user.
566626 """
627 logger.debug("Using the OAuth2 access_token to request userinfo")
567628 metadata = await self.load_metadata()
568629
569630 resp = await self._http_client.get_json(
570631 metadata["userinfo_endpoint"],
571632 headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
572633 )
634
635 logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
573636
574637 return UserInfo(resp)
575638
599662 claims_cls = ImplicitIDToken
600663
601664 alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
602
603665 jwt = JsonWebToken(alg_values)
604666
605667 claim_options = {"iss": {"values": [metadata["issuer"]]}}
668
669 id_token = token["id_token"]
670 logger.debug("Attempting to decode JWT id_token %r", id_token)
606671
607672 # Try to decode the keys in cache first, then retry by forcing the keys
608673 # to be reloaded
609674 jwk_set = await self.load_jwks()
610675 try:
611676 claims = jwt.decode(
612 token["id_token"],
677 id_token,
613678 key=jwk_set,
614679 claims_cls=claims_cls,
615680 claims_options=claim_options,
619684 logger.info("Reloading JWKS after decode error")
620685 jwk_set = await self.load_jwks(force=True) # try reloading the jwks
621686 claims = jwt.decode(
622 token["id_token"],
687 id_token,
623688 key=jwk_set,
624689 claims_cls=claims_cls,
625690 claims_options=claim_options,
626691 claims_params=claims_params,
627692 )
693
694 logger.debug("Decoded id_token JWT %r; validating", claims)
628695
629696 claims.validate(leeway=120) # allows 2 min of clock skew
630697 return UserInfo(claims)
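The reload-and-retry dance above generalises to a small pattern: verify the token against the cached key set first, and only on failure force a refresh of the JWKS before one retry, which covers identity-provider key rotation without hitting the `jwks_uri` on every login. A hedged sketch, with `decode` and `load_keys` as stand-ins for authlib's decoder and the `load_jwks` method:

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

K = TypeVar("K")
R = TypeVar("R")


async def decode_with_key_refresh(
    decode: Callable[[K], R],
    load_keys: Callable[[bool], Awaitable[K]],
) -> R:
    """Try the cached JWKS first; on any decode error, force-reload the
    key set once and retry (handles identity-provider key rotation)."""
    keys = await load_keys(False)
    try:
        return decode(keys)
    except Exception:
        # force=True: discard the cached key set and fetch fresh keys
        keys = await load_keys(True)
        return decode(keys)
```

If the second attempt also fails, the exception propagates to the caller, matching the behaviour of the handler above.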
680747 ui_auth_session_id=ui_auth_session_id,
681748 ),
682749 )
683 request.addCookie(
684 SESSION_COOKIE_NAME,
685 cookie,
686 path="/_synapse/client/oidc",
687 max_age="3600",
688 httpOnly=True,
689 sameSite="lax",
690 )
750
751 # Set the cookies. See the comments on _SESSION_COOKIES for why there are two.
752 #
753 # we have to build the header by hand rather than calling request.addCookie
754 # because the latter does not support SameSite=None
755 # (https://twistedmatrix.com/trac/ticket/10088)
756
757 for cookie_name, options in _SESSION_COOKIES:
758 request.cookies.append(
759 b"%s=%s; Max-Age=3600; %s"
760 % (cookie_name, cookie.encode("utf-8"), options)
761 )
691762
692763 metadata = await self.load_metadata()
693764 authorization_endpoint = metadata.get("authorization_endpoint")
725796 """
726797 # Exchange the code with the provider
727798 try:
728 logger.debug("Exchanging code")
799 logger.debug("Exchanging OAuth2 code for a token")
729800 token = await self._exchange_code(code)
730801 except OidcError as e:
731 logger.exception("Could not exchange code")
802 logger.exception("Could not exchange OAuth2 code")
732803 self._sso_handler.render_error(request, e.error, e.error_description)
733804 return
734805
735 logger.debug("Successfully obtained OAuth2 access token")
806 logger.debug("Successfully obtained OAuth2 token data: %r", token)
736807
737808 # Now that we have a token, get the userinfo, either by decoding the
738809 # `id_token` or by fetching the `userinfo_endpoint`.
739810 if self._uses_userinfo:
740 logger.debug("Fetching userinfo")
741811 try:
742812 userinfo = await self._fetch_userinfo(token)
743813 except Exception as e:
745815 self._sso_handler.render_error(request, "fetch_error", str(e))
746816 return
747817 else:
748 logger.debug("Extracting userinfo from id_token")
749818 try:
750819 userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
751820 except Exception as e:
9381007 A signed macaroon token with the session information.
9391008 """
9401009 macaroon = pymacaroons.Macaroon(
941 location=self._server_name, identifier="key", key=self._macaroon_secret_key,
1010 location=self._server_name,
1011 identifier="key",
1012 key=self._macaroon_secret_key,
9421013 )
9431014 macaroon.add_first_party_caveat("gen = 1")
9441015 macaroon.add_first_party_caveat("type = session")
196196 stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
197197
198198 r = await self.store.get_room_event_before_stream_ordering(
199 room_id, stream_ordering,
199 room_id,
200 stream_ordering,
200201 )
201202 if not r:
202203 logger.warning(
222223 # the background so that it's not blocking any other operation apart from
223224 # other purges in the same room.
224225 run_as_background_process(
225 "_purge_history", self._purge_history, purge_id, room_id, token, True,
226 "_purge_history",
227 self._purge_history,
228 purge_id,
229 room_id,
230 token,
231 True,
226232 )
227233
228234 def start_purge_history(
388394 )
389395
390396 await self.hs.get_federation_handler().maybe_backfill(
391 room_id, curr_topo, limit=pagin_config.limit,
397 room_id,
398 curr_topo,
399 limit=pagin_config.limit,
392400 )
393401
394402 to_room_key = None
348348 [self.user_to_current_state[user_id] for user_id in unpersisted]
349349 )
350350
351 async def _update_states(self, new_states):
351 async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None:
352352 """Updates presence of users. Sets the appropriate timeouts. Pokes
353353 the notifier and federation if and only if the changed presence state
354354 should be sent to clients/servers.
355
356 Args:
357 new_states: The new user presence state updates to process.
355358 """
356359 now = self.clock.time_msec()
357360
367370 new_states_dict = {}
368371 for new_state in new_states:
369372 new_states_dict[new_state.user_id] = new_state
370 new_state = new_states_dict.values()
373 new_states = new_states_dict.values()
371374
372375 for new_state in new_states:
373376 user_id = new_state.user_id
634637 self.external_process_last_updated_ms.pop(process_id, None)
635638
636639 async def current_state_for_user(self, user_id):
637 """Get the current presence state for a user.
638 """
640 """Get the current presence state for a user."""
639641 res = await self.current_state_for_users([user_id])
640642 return res[user_id]
641643
657659
658660 self._push_to_remotes(states)
659661
660 async def notify_for_states(self, state, stream_id):
661 parties = await get_interested_parties(self.store, [state])
662 room_ids_to_states, users_to_states = parties
663
664 self.notifier.on_new_event(
665 "presence_key",
666 stream_id,
667 rooms=room_ids_to_states.keys(),
668 users=[UserID.from_string(u) for u in users_to_states],
669 )
670
671662 def _push_to_remotes(self, states):
672663 """Sends state updates to remote servers.
673664
677668 self.federation.send_presence(states)
678669
679670 async def incoming_presence(self, origin, content):
680 """Called when we receive a `m.presence` EDU from a remote server.
681 """
671 """Called when we receive a `m.presence` EDU from a remote server."""
682672 if not self._presence_enabled:
683673 return
684674
728718 await self._update_states(updates)
729719
730720 async def set_state(self, target_user, state, ignore_status_msg=False):
731 """Set the presence state of the user.
732 """
721 """Set the presence state of the user."""
733722 status_msg = state.get("status_msg", None)
734723 presence = state["presence"]
735724
757746 await self._update_states([prev_state.copy_and_replace(**new_fields)])
758747
759748 async def is_visible(self, observed_user, observer_user):
760 """Returns whether a user can see another user's presence.
761 """
749 """Returns whether a user can see another user's presence."""
762750 observer_room_ids = await self.store.get_rooms_for_user(
763751 observer_user.to_string()
764752 )
952940
953941
954942 def should_notify(old_state, new_state):
955 """Decides if a presence state change should be sent to interested parties.
956 """
943 """Decides if a presence state change should be sent to interested parties."""
957944 if old_state == new_state:
958945 return False
959946
206206 # This must be done by the target user themselves.
207207 if by_admin:
208208 requester = create_requester(
209 target_user, authenticated_entity=requester.authenticated_entity,
209 target_user,
210 authenticated_entity=requester.authenticated_entity,
210211 )
211212
212213 await self.store.set_profile_displayname(
4848 )
4949 else:
5050 hs.get_federation_registry().register_instances_for_edu(
51 "m.receipt", hs.config.worker.writers.receipts,
51 "m.receipt",
52 hs.config.worker.writers.receipts,
5253 )
5354
5455 self.clock = self.hs.get_clock()
5556 self.state = hs.get_state_handler()
5657
5758 async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
58 """Called when we receive an EDU of type m.receipt from a remote HS.
59 """
59 """Called when we receive an EDU of type m.receipt from a remote HS."""
6060 receipts = []
6161 for room_id, room_values in content.items():
6262 for receipt_type, users in room_values.items():
8282 await self._handle_new_receipts(receipts)
8383
8484 async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
85 """Takes a list of receipts, stores them and informs the notifier.
86 """
85 """Takes a list of receipts, stores them and informs the notifier."""
8786 min_batch_id = None # type: Optional[int]
8887 max_batch_id = None # type: Optional[int]
8988
6161 self._register_device_client = RegisterDeviceReplicationServlet.make_client(
6262 hs
6363 )
64 self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
65 hs
64 self._post_registration_client = (
65 ReplicationPostRegisterActionsServlet.make_client(hs)
6666 )
6767 else:
6868 self.device_handler = hs.get_device_handler()
188188 self.check_registration_ratelimit(address)
189189
190190 result = await self.spam_checker.check_registration_for_spam(
191 threepid, localpart, user_agent_ips or [],
191 threepid,
192 localpart,
193 user_agent_ips or [],
192194 )
193195
194196 if result == RegistrationBehaviour.DENY:
195197 logger.info(
196 "Blocked registration of %r", localpart,
198 "Blocked registration of %r",
199 localpart,
197200 )
198201 # We return a 429 to make it not obvious that they've been
199202 # denied.
202205 shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
203206 if shadow_banned:
204207 logger.info(
205 "Shadow banning registration of %r", localpart,
208 "Shadow banning registration of %r",
209 localpart,
206210 )
207211
208212 # do not check_auth_blocking if the call is coming through the Admin API
368372 config["room_alias_name"] = room_alias.localpart
369373
370374 info, _ = await room_creation_handler.create_room(
371 fake_requester, config=config, ratelimit=False,
375 fake_requester,
376 config=config,
377 ratelimit=False,
372378 )
373379
374380 # If the room does not require an invite, but another user
752758 return
753759
754760 await self._auth_handler.add_threepid(
755 user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
761 user_id,
762 threepid["medium"],
763 threepid["address"],
764 threepid["validated_at"],
756765 )
757766
758767 # And we add an email pusher for them by default, but only
804813 raise
805814
806815 await self._auth_handler.add_threepid(
807 user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
816 user_id,
817 threepid["medium"],
818 threepid["address"],
819 threepid["validated_at"],
808820 )
3737 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
3838 from synapse.events import EventBase
3939 from synapse.events.utils import copy_power_levels_contents
40 from synapse.rest.admin._base import assert_user_is_admin
4041 from synapse.storage.state import StateFilter
4142 from synapse.types import (
4243 JsonDict,
196197 if r is None:
197198 raise NotFoundError("Unknown room id %s" % (old_room_id,))
198199 new_room_id = await self._generate_room_id(
199 creator_id=user_id, is_public=r["is_public"], room_version=new_version,
200 creator_id=user_id,
201 is_public=r["is_public"],
202 room_version=new_version,
200203 )
201204
202205 logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
234237
235238 # now send the tombstone
236239 await self.event_creation_handler.handle_new_client_event(
237 requester=requester, event=tombstone_event, context=tombstone_context,
240 requester=requester,
241 event=tombstone_event,
242 context=tombstone_context,
238243 )
239244
240245 old_room_state = await tombstone_context.get_current_state_ids()
255260 # finally, shut down the PLs in the old room, and update them in the new
256261 # room.
257262 await self._update_upgraded_room_pls(
258 requester, old_room_id, new_room_id, old_room_state,
263 requester,
264 old_room_id,
265 new_room_id,
266 old_room_state,
259267 )
260268
261269 return new_room_id
423431
424432 # Copy over user power levels now as this will not be possible with >100PL users once
425433 # the room has been created
426
427434 # Calculate the minimum power level needed to clone the room
428435 event_power_levels = power_levels.get("events", {})
429 state_default = power_levels.get("state_default", 0)
430 ban = power_levels.get("ban")
436 state_default = power_levels.get("state_default", 50)
437 ban = power_levels.get("ban", 50)
431438 needed_power_level = max(state_default, ban, max(event_power_levels.values()))
432439
 440 # Get the user's current power level; this matches the logic in get_user_power_level,
441 # but without the entire state map.
442 user_power_levels = power_levels.setdefault("users", {})
443 users_default = power_levels.get("users_default", 0)
444 current_power_level = user_power_levels.get(user_id, users_default)
433445 # Raise the requester's power level in the new room if necessary
434 current_power_level = power_levels["users"][user_id]
435446 if current_power_level < needed_power_level:
436 power_levels["users"][user_id] = needed_power_level
447 user_power_levels[user_id] = needed_power_level
437448
438449 await self._send_events_for_new_room(
439450 requester,
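
The hunk above changes the defaults used when computing the power level needed to clone the room during an upgrade: `state_default` and `ban` now fall back to the spec default of 50 rather than 0/`None`, and the requester's current level is read via `users_default` instead of assuming a `users` entry for the user exists. A minimal standalone sketch of the corrected calculation follows; the function name is hypothetical, and the `default=0` guard for an empty `events` map is an extra robustness assumption not present upstream:

```python
def compute_upgrade_power_levels(power_levels: dict, user_id: str) -> dict:
    """Raise user_id's power level if it is below what is needed to clone
    the room. Standalone sketch of the fixed logic, not the Synapse code."""
    event_power_levels = power_levels.get("events", {})
    # Fall back to the Matrix spec defaults (50), not falsy values.
    state_default = power_levels.get("state_default", 50)
    ban = power_levels.get("ban", 50)
    needed = max(state_default, ban, max(event_power_levels.values(), default=0))

    # Read the user's current level without assuming a "users" entry exists;
    # this mirrors get_user_power_level but without the full state map.
    user_power_levels = power_levels.setdefault("users", {})
    users_default = power_levels.get("users_default", 0)
    current = user_power_levels.get(user_id, users_default)

    if current < needed:
        user_power_levels[user_id] = needed
    return power_levels
```

The previous code indexed `power_levels["users"][user_id]` directly, which raised a `KeyError` for rooms whose power-levels event had no entry for the upgrading user.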
565576 ratelimit: bool = True,
566577 creator_join_profile: Optional[JsonDict] = None,
567578 ) -> Tuple[dict, int]:
568 """ Creates a new room.
579 """Creates a new room.
569580
570581 Args:
571582 requester:
686697 is_public = visibility == "public"
687698
688699 room_id = await self._generate_room_id(
689 creator_id=user_id, is_public=is_public, room_version=room_version,
700 creator_id=user_id,
701 is_public=is_public,
702 room_version=room_version,
690703 )
691704
692705 # Check whether this visibility value is blocked by a third party module
827840 if room_alias:
828841 result["room_alias"] = room_alias.to_string()
829842
830 # Always wait for room creation to progate before returning
843 # Always wait for room creation to propagate before returning
831844 await self._replication.wait_for_stream_position(
832845 self.hs.config.worker.events_shard_config.get_instance(room_id),
833846 "events",
879892 _,
880893 last_stream_id,
881894 ) = await self.event_creation_handler.create_and_send_nonmember_event(
882 creator, event, ratelimit=False, ignore_shadow_ban=True,
895 creator,
896 event,
897 ratelimit=False,
898 ignore_shadow_ban=True,
883899 )
884900 return last_stream_id
885901
979995 return last_sent_stream_id
980996
981997 async def _generate_room_id(
982 self, creator_id: str, is_public: bool, room_version: RoomVersion,
998 self,
999 creator_id: str,
1000 is_public: bool,
1001 room_version: RoomVersion,
9831002 ):
9841003 # autogen room IDs and try to create it. We may clash, so just
9851004 # try a few times till one goes through, giving up eventually.
10031022 class RoomContextHandler:
10041023 def __init__(self, hs: "HomeServer"):
10051024 self.hs = hs
1025 self.auth = hs.get_auth()
10061026 self.store = hs.get_datastore()
10071027 self.storage = hs.get_storage()
10081028 self.state_store = self.storage.state
10091029
10101030 async def get_event_context(
10111031 self,
1012 user: UserID,
1032 requester: Requester,
10131033 room_id: str,
10141034 event_id: str,
10151035 limit: int,
10161036 event_filter: Optional[Filter],
1037 use_admin_priviledge: bool = False,
10171038 ) -> Optional[JsonDict]:
10181039 """Retrieves events, pagination tokens and state around a given event
10191040 in a room.
10201041
10211042 Args:
1022 user
1043 requester
10231044 room_id
10241045 event_id
10251046 limit: The maximum number of events to return in total
10261047 (excluding state).
10271048 event_filter: the filter to apply to the events returned
10281049 (excluding the target event_id)
1029
1050 use_admin_priviledge: if `True`, return all events, regardless
1051 of whether `user` has access to them. To be used **ONLY**
1052 from the admin API.
10301053 Returns:
10311054 dict, or None if the event isn't found
10321055 """
1056 user = requester.user
1057 if use_admin_priviledge:
1058 await assert_user_is_admin(self.auth, requester.user)
1059
10331060 before_limit = math.floor(limit / 2.0)
10341061 after_limit = limit - before_limit
10351062
10361063 users = await self.store.get_users_in_room(room_id)
10371064 is_peeking = user.to_string() not in users
10381065
1039 def filter_evts(events):
1040 return filter_events_for_client(
1066 async def filter_evts(events):
1067 if use_admin_priviledge:
1068 return events
1069 return await filter_events_for_client(
10411070 self.storage, user.to_string(), events, is_peeking=is_peeking
10421071 )
10431072
190190 # do it up front for efficiency.)
191191 if txn_id and requester.access_token_id:
192192 existing_event_id = await self.store.get_event_id_from_transaction_id(
193 room_id, requester.user.to_string(), requester.access_token_id, txn_id,
193 room_id,
194 requester.user.to_string(),
195 requester.access_token_id,
196 txn_id,
194197 )
195198 if existing_event_id:
196199 event_pos = await self.store.get_position_for_event(existing_event_id)
237240 )
238241
239242 result_event = await self.event_creation_handler.handle_new_client_event(
240 requester, event, context, extra_users=[target], ratelimit=ratelimit,
243 requester,
244 event,
245 context,
246 extra_users=[target],
247 ratelimit=ratelimit,
241248 )
242249
243250 if event.membership == Membership.LEAVE:
582589 # send the rejection to the inviter's HS (with fallback to
583590 # local event)
584591 return await self.remote_reject_invite(
585 invite.event_id, txn_id, requester, content,
592 invite.event_id,
593 txn_id,
594 requester,
595 content,
586596 )
587597
588598 # the inviter was on our server, but has now left. Carry on
10551065 user: UserID,
10561066 content: dict,
10571067 ) -> Tuple[str, int]:
1058 """Implements RoomMemberHandler._remote_join
1059 """
1068 """Implements RoomMemberHandler._remote_join"""
10601069 # filter ourselves out of remote_room_hosts: do_invite_join ignores it
10611070 # and if it is the only entry we'd like to return a 404 rather than a
10621071 # 500.
12101219 event.internal_metadata.out_of_band_membership = True
12111220
12121221 result_event = await self.event_creation_handler.handle_new_client_event(
1213 requester, event, context, extra_users=[UserID.from_string(target_user)],
1222 requester,
1223 event,
1224 context,
1225 extra_users=[UserID.from_string(target_user)],
12141226 )
12151227 # we know it was persisted, so must have a stream ordering
12161228 assert result_event.internal_metadata.stream_ordering
12181230 return result_event.event_id, result_event.internal_metadata.stream_ordering
12191231
12201232 async def _user_left_room(self, target: UserID, room_id: str) -> None:
1221 """Implements RoomMemberHandler._user_left_room
1222 """
1233 """Implements RoomMemberHandler._user_left_room"""
12231234 user_left_room(self.distributor, target, room_id)
12241235
12251236 async def forget(self, user: UserID, room_id: str) -> None:
4343 user: UserID,
4444 content: dict,
4545 ) -> Tuple[str, int]:
46 """Implements RoomMemberHandler._remote_join
47 """
46 """Implements RoomMemberHandler._remote_join"""
4847 if len(remote_room_hosts) == 0:
4948 raise SynapseError(404, "No known servers")
5049
7978 return ret["event_id"], ret["stream_id"]
8079
8180 async def _user_left_room(self, target: UserID, room_id: str) -> None:
82 """Implements RoomMemberHandler._user_left_room
83 """
81 """Implements RoomMemberHandler._user_left_room"""
8482 await self._notify_change_client(
8583 user_id=target.to_string(), room_id=room_id, change="left"
8684 )
2222
2323 from synapse.api.errors import SynapseError
2424 from synapse.config import ConfigError
25 from synapse.config.saml2_config import SamlAttributeRequirement
2625 from synapse.handlers._base import BaseHandler
2726 from synapse.handlers.sso import MappingException, UserAttributes
2827 from synapse.http.servlet import parse_string
121120
122121 now = self.clock.time_msec()
123122 self._outstanding_requests_dict[reqid] = Saml2SessionData(
124 creation_time=now, ui_auth_session_id=ui_auth_session_id,
123 creation_time=now,
124 ui_auth_session_id=ui_auth_session_id,
125125 )
126126
127127 for key, value in info["headers"]:
238238
239239 # Ensure that the attributes of the logged in user meet the required
240240 # attributes.
241 for requirement in self._saml2_attribute_requirements:
242 if not _check_attribute_requirement(saml2_auth.ava, requirement):
243 self._sso_handler.render_error(
244 request, "unauthorised", "You are not authorised to log in here."
245 )
246 return
241 if not self._sso_handler.check_required_attributes(
242 request, saml2_auth.ava, self._saml2_attribute_requirements
243 ):
244 return
247245
248246 # Call the mapper to register/login the user
249247 try:
372370 del self._outstanding_requests_dict[reqid]
373371
374372
375 def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
376 values = ava.get(req.attribute, [])
377 for v in values:
378 if v == req.value:
379 return True
380
381 logger.info(
382 "SAML2 attribute %s did not match required value '%s' (was '%s')",
383 req.attribute,
384 req.value,
385 values,
386 )
387 return False
388
389
390373 DOT_REPLACE_PATTERN = re.compile(
391374 ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
392375 )
467450 mxid_source = saml_response.ava[self._mxid_source_attribute][0]
468451 except KeyError:
469452 logger.warning(
470 "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
453 "SAML2 response lacks a '%s' attestation",
454 self._mxid_source_attribute,
471455 )
472456 raise SynapseError(
473457 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
1515 import logging
1616 from typing import (
1717 TYPE_CHECKING,
18 Any,
1819 Awaitable,
1920 Callable,
2021 Dict,
2122 Iterable,
23 List,
2224 Mapping,
2325 Optional,
2426 Set,
3335
3436 from synapse.api.constants import LoginType
3537 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
38 from synapse.config.sso import SsoAttributeRequirement
3639 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
3740 from synapse.http import get_request_user_agent
3841 from synapse.http.server import respond_with_html, respond_with_redirect
323326
324327 # Check if we already have a mapping for this user.
325328 previously_registered_user_id = await self._store.get_user_by_external_id(
326 auth_provider_id, remote_user_id,
329 auth_provider_id,
330 remote_user_id,
327331 )
328332
329333 # A match was found, return the user ID.
412416 with await self._mapping_lock.queue(auth_provider_id):
413417 # first of all, check if we already have a mapping for this user
414418 user_id = await self.get_sso_user_by_remote_user_id(
415 auth_provider_id, remote_user_id,
419 auth_provider_id,
420 remote_user_id,
416421 )
417422
418423 # Check for grandfathering of users.
457462 )
458463
459464 async def _call_attribute_mapper(
460 self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
465 self,
466 sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
461467 ) -> UserAttributes:
462468 """Call the attribute mapper function in a loop, until we get a unique userid"""
463469 for i in range(self._MAP_USERNAME_RETRIES):
628634 """
629635
630636 user_id = await self.get_sso_user_by_remote_user_id(
631 auth_provider_id, remote_user_id,
637 auth_provider_id,
638 remote_user_id,
632639 )
633640
634641 user_id_to_verify = await self._auth_handler.get_session_data(
667674
668675 # render an error page.
669676 html = self._bad_user_template.render(
670 server_name=self._server_name, user_id_to_verify=user_id_to_verify,
677 server_name=self._server_name,
678 user_id_to_verify=user_id_to_verify,
671679 )
672680 respond_with_html(request, 200, html)
673681
691699 raise SynapseError(400, "unknown session")
692700
693701 async def check_username_availability(
694 self, localpart: str, session_id: str,
702 self,
703 localpart: str,
704 session_id: str,
695705 ) -> bool:
696706 """Handle an "is username available" callback check
697707
741751 use_display_name: whether the user wants to use the suggested display name
742752 emails_to_use: emails that the user would like to use
743753 """
744 session = self.get_mapping_session(session_id)
754 try:
755 session = self.get_mapping_session(session_id)
756 except SynapseError as e:
757 self.render_error(request, "bad_session", e.msg, code=e.code)
758 return
745759
746760 # update the session with the user's choices
747761 session.chosen_localpart = localpart
792806 session_id,
793807 terms_version,
794808 )
795 session = self.get_mapping_session(session_id)
809 try:
810 session = self.get_mapping_session(session_id)
811 except SynapseError as e:
812 self.render_error(request, "bad_session", e.msg, code=e.code)
813 return
814
796815 session.terms_accepted_version = terms_version
797816
798817 # we're done; now we can register the user
807826 request: HTTP request
808827 session_id: ID of the username mapping session, extracted from a cookie
809828 """
810 session = self.get_mapping_session(session_id)
829 try:
830 session = self.get_mapping_session(session_id)
831 except SynapseError as e:
832 self.render_error(request, "bad_session", e.msg, code=e.code)
833 return
811834
812835 logger.info(
813836 "[session %s] Registering localpart %s",
816839 )
817840
818841 attributes = UserAttributes(
819 localpart=session.chosen_localpart, emails=session.emails_to_use,
842 localpart=session.chosen_localpart,
843 emails=session.emails_to_use,
820844 )
821845
822846 if session.use_display_name:
879903 logger.info("Expiring mapping session %s", session_id)
880904 del self._username_mapping_sessions[session_id]
881905
906 def check_required_attributes(
907 self,
908 request: SynapseRequest,
909 attributes: Mapping[str, List[Any]],
910 attribute_requirements: Iterable[SsoAttributeRequirement],
911 ) -> bool:
912 """
913 Confirm that the required attributes were present in the SSO response.
914
915 If all requirements are met, this will return True.
916
917 If any requirement is not met, then the request will be finalized by
918 showing an error page to the user and False will be returned.
919
920 Args:
921 request: The request to (potentially) respond to.
922 attributes: The attributes from the SSO IdP.
923 attribute_requirements: The requirements that attributes must meet.
924
925 Returns:
926 True if all requirements are met, False if any attribute fails to
927 meet the requirement.
928
929 """
930 # Ensure that the attributes of the logged in user meet the required
931 # attributes.
932 for requirement in attribute_requirements:
933 if not _check_attribute_requirement(attributes, requirement):
934 self.render_error(
935 request, "unauthorised", "You are not authorised to log in here."
936 )
937 return False
938
939 return True
940
882941
883942 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
884943 """Extract the session ID from the cookie
889948 if not session_id:
890949 raise SynapseError(code=400, msg="missing session_id")
891950 return session_id.decode("ascii", errors="replace")
951
952
953 def _check_attribute_requirement(
954 attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
955 ) -> bool:
956 """Check if SSO attributes meet the proper requirements.
957
958 Args:
959 attributes: A mapping of attributes to an iterable of one or more values.
 960 req: The configured requirement to check.
961
962 Returns:
963 True if the required attribute was found and had a proper value.
964 """
965 if req.attribute not in attributes:
966 logger.info("SSO attribute missing: %s", req.attribute)
967 return False
968
969 # If the requirement is None, the attribute existing is enough.
970 if req.value is None:
971 return True
972
973 values = attributes[req.attribute]
974 if req.value in values:
975 return True
976
977 logger.info(
978 "SSO attribute %s did not match required value '%s' (was '%s')",
979 req.attribute,
980 req.value,
981 values,
982 )
983 return False
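
The generalised `_check_attribute_requirement` that moved into `sso.py` also gains one behaviour over the SAML-specific version it replaces: a requirement whose `value` is `None` is satisfied by the attribute merely being present. A self-contained sketch of that check, where the dataclass is a simplified stand-in for `synapse.config.sso.SsoAttributeRequirement`:

```python
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional


@dataclass
class AttributeRequirement:
    """Simplified stand-in for synapse.config.sso.SsoAttributeRequirement."""

    attribute: str
    value: Optional[str] = None  # None means the attribute just has to exist


def check_requirement(
    attributes: Mapping[str, List[Any]], req: AttributeRequirement
) -> bool:
    if req.attribute not in attributes:
        return False  # attribute missing entirely
    if req.value is None:
        return True  # presence alone satisfies the requirement
    return req.value in attributes[req.attribute]
```

`check_required_attributes` then simply loops over the configured requirements and renders the "unauthorised" error page on the first failure, as the diff shows.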
6262 self.clock.call_later(0, self.notify_new_event)
6363
6464 def notify_new_event(self) -> None:
65 """Called when there may be more deltas to process
66 """
65 """Called when there may be more deltas to process"""
6766 if not self.stats_enabled or self._is_processing:
6867 return
6968
338338 since_token: Optional[StreamToken] = None,
339339 full_state: bool = False,
340340 ) -> SyncResult:
341 """Get the sync for client needed to match what the server has now.
342 """
341 """Get the sync for client needed to match what the server has now."""
343342 return await self.generate_sync_result(sync_config, since_token, full_state)
344343
345344 async def push_rules_for_user(self, user: UserID) -> JsonDict:
563562 stream_position: StreamToken,
564563 state_filter: StateFilter = StateFilter.all(),
565564 ) -> StateMap[str]:
566 """ Get the room state at a particular stream position
565 """Get the room state at a particular stream position
567566
568567 Args:
569568 room_id: room for which to get state
597596 state: MutableStateMap[EventBase],
598597 now_token: StreamToken,
599598 ) -> Optional[JsonDict]:
600 """ Works out a room summary block for this room, summarising the number
599 """Works out a room summary block for this room, summarising the number
601600 of joined members in the room, and providing the 'hero' members if the
602601 room has no name so clients can consistently name rooms. Also adds
603602 state events to 'state' if needed to describe the heroes.
742741 now_token: StreamToken,
743742 full_state: bool,
744743 ) -> MutableStateMap[EventBase]:
745 """ Works out the difference in state between the start of the timeline
744 """Works out the difference in state between the start of the timeline
746745 and the previous sync.
747746
748747 Args:
819818 )
820819 elif batch.limited:
821820 if batch:
822 state_at_timeline_start = await self.state_store.get_state_ids_for_event(
823 batch.events[0].event_id, state_filter=state_filter
821 state_at_timeline_start = (
822 await self.state_store.get_state_ids_for_event(
823 batch.events[0].event_id, state_filter=state_filter
824 )
824825 )
825826 else:
826827 # We can get here if the user has ignored the senders of all
954955 since_token: Optional[StreamToken] = None,
955956 full_state: bool = False,
956957 ) -> SyncResult:
957 """Generates a sync result.
958 """
958 """Generates a sync result."""
959959 # NB: The now_token gets changed by some of the generate_sync_* methods,
960960 # this is due to some of the underlying streams not supporting the ability
961961 # to query up to a given point.
10291029 one_time_key_counts = await self.store.count_e2e_one_time_keys(
10301030 user_id, device_id
10311031 )
1032 unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
1033 user_id, device_id
1032 unused_fallback_key_types = (
1033 await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
10341034 )
10351035
10361036 logger.debug("Fetching group data")
11751175 # weren't in the previous sync *or* they left and rejoined.
11761176 users_that_have_changed.update(newly_joined_or_invited_users)
11771177
1178 user_signatures_changed = await self.store.get_users_whose_signatures_changed(
1179 user_id, since_token.device_list_key
1178 user_signatures_changed = (
1179 await self.store.get_users_whose_signatures_changed(
1180 user_id, since_token.device_list_key
1181 )
11801182 )
11811183 users_that_have_changed.update(user_signatures_changed)
11821184
13921394 logger.debug("no-oping sync")
13931395 return set(), set(), set(), set()
13941396
1395 ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
1396 AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
1397 ignored_account_data = (
1398 await self.store.get_global_account_data_by_type_for_user(
1399 AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
1400 )
13971401 )
13981402
13991403 # If there is ignored users account data and it matches the proper type,
14981502 async def _get_rooms_changed(
14991503 self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
15001504 ) -> _RoomChanges:
1501 """Gets the the changes that have happened since the last sync.
1502 """
1505 """Gets the the changes that have happened since the last sync."""
15031506 user_id = sync_result_builder.sync_config.user.to_string()
15041507 since_token = sync_result_builder.since_token
15051508 now_token = sync_result_builder.now_token
6060
6161 if hs.config.worker.writers.typing != hs.get_instance_name():
6262 hs.get_federation_registry().register_instance_for_edu(
63 "m.typing", hs.config.worker.writers.typing,
63 "m.typing",
64 hs.config.worker.writers.typing,
6465 )
6566
6667 # map room IDs to serial numbers
7576 self.clock.looping_call(self._handle_timeouts, 5000)
7677
7778 def _reset(self) -> None:
78 """Reset the typing handler's data caches.
79 """
79 """Reset the typing handler's data caches."""
8080 # map room IDs to serial numbers
8181 self._room_serials = {}
8282 # map room IDs to sets of users currently typing
148148 def process_replication_rows(
149149 self, token: int, rows: List[TypingStream.TypingStreamRow]
150150 ) -> None:
151 """Should be called whenever we receive updates for typing stream.
152 """
151 """Should be called whenever we receive updates for typing stream."""
153152
154153 if self._latest_room_serial > token:
155154 # The master has gone backwards. To prevent inconsistent data, just
9696 return results
9797
9898 def notify_new_event(self) -> None:
99 """Called when there may be more deltas to process
100 """
99 """Called when there may be more deltas to process"""
101100 if not self.update_user_directory:
102101 return
103102
133132 )
134133
135134 async def handle_user_deactivated(self, user_id: str) -> None:
136 """Called when a user ID is deactivated
137 """
135 """Called when a user ID is deactivated"""
138136 # FIXME(#3714): We should probably do this in the same worker as all
139137 # the other changes.
140138 await self.store.remove_from_user_dir(user_id)
143141 # If self.pos is None then means we haven't fetched it from DB
144142 if self.pos is None:
145143 self.pos = await self.store.get_user_directory_stream_pos()
144
145 # If still None then the initial background update hasn't happened yet.
146 if self.pos is None:
147 return None
146148
147149 # Loop round handling deltas until we're up to date
148150 while True:
171173 await self.store.update_user_directory_stream_pos(max_pos)
172174
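
The user-directory hunk adds a second guard: after fetching the stream position from the database, `self.pos` can still be `None` if the initial background update that populates the directory has not yet completed, in which case delta processing is skipped until a later tick. The control flow, sketched with a hypothetical async accessor standing in for `store.get_user_directory_stream_pos`:

```python
from typing import Awaitable, Callable, Optional


async def process_deltas(
    pos: Optional[int], get_stream_pos: Callable[[], Awaitable[Optional[int]]]
) -> Optional[int]:
    """Sketch of the added guard; not the Synapse method itself."""
    if pos is None:
        # Not yet fetched from the DB on this worker.
        pos = await get_stream_pos()

    if pos is None:
        # Initial background update hasn't happened yet; try again later.
        return None

    # ... loop round handling deltas until up to date ...
    return pos
```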
173175 async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
174 """Called with the state deltas to process
175 """
176 """Called with the state deltas to process"""
176177 for delta in deltas:
177178 typ = delta["type"]
178179 state_key = delta["state_key"]
5353
5454
5555 def get_request_user_agent(request: IRequest, default: str = "") -> str:
56 """Return the last User-Agent header, or the given default.
57 """
56 """Return the last User-Agent header, or the given default."""
5857 # There could be raw utf-8 bytes in the User-Agent header.
5958
6059 # N.B. if you don't do this, the logger explodes cryptically
5555 )
5656 from twisted.web.http import PotentialDataLoss
5757 from twisted.web.http_headers import Headers
58 from twisted.web.iweb import IAgent, IBodyProducer, IResponse
58 from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
5959
6060 from synapse.api.errors import Codes, HttpResponseException, SynapseError
6161 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
397397 body_producer = None
398398 if data is not None:
399399 body_producer = QuieterFileBodyProducer(
400 BytesIO(data), cooperator=self._cooperator,
400 BytesIO(data),
401 cooperator=self._cooperator,
401402 )
402403
403404 request_deferred = treq.request(
406407 agent=self.agent,
407408 data=body_producer,
408409 headers=headers,
410 # Avoid buffering the body in treq since we do not reuse
411 # response bodies.
412 unbuffered=True,
409413 **self._extra_treq_args,
410414 ) # type: defer.Deferred
411415
412416 # we use our own timeout mechanism rather than treq's as a workaround
413417 # for https://twistedmatrix.com/trac/ticket/9534.
414418 request_deferred = timeout_deferred(
415 request_deferred, 60, self.hs.get_reactor(),
419 request_deferred,
420 60,
421 self.hs.get_reactor(),
416422 )
417423
418424 # turn timeouts into RequestTimedOutErrors
698704
699705 resp_headers = dict(response.headers.getAllRawHeaders())
700706
701 if (
702 b"Content-Length" in resp_headers
703 and max_size
704 and int(resp_headers[b"Content-Length"][0]) > max_size
705 ):
706 logger.warning("Requested URL is too large > %r bytes" % (max_size,))
707 raise SynapseError(
708 502,
709 "Requested file is too large > %r bytes" % (max_size,),
710 Codes.TOO_LARGE,
711 )
712
713707 if response.code > 299:
714708 logger.warning("Got %d when downloading %s" % (response.code, url))
715709 raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
776770 # in the meantime.
777771 if self.max_size is not None and self.length >= self.max_size:
778772 self.deferred.errback(BodyExceededMaxSize())
779 self.transport.loseConnection()
773 # Close the connection (forcefully) since all the data will get
774 # discarded anyway.
775 self.transport.abortConnection()
780776
781777 def connectionLost(self, reason: Failure) -> None:
782778 # If the maximum size was already exceeded, there's nothing to do.
810806 Returns:
811807 A Deferred which resolves to the length of the read body.
812808 """
809 # If the Content-Length header gives a size larger than the maximum allowed
810 # size, do not bother downloading the body.
811 if max_size is not None and response.length != UNKNOWN_LENGTH:
812 if response.length > max_size:
813 return defer.fail(BodyExceededMaxSize())
813814
814815 d = defer.Deferred()
815816 response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
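
The change to `read_body_with_max_size` short-circuits before any body bytes are delivered when the response advertises a Content-Length above the cap; Twisted exposes the advertised length as `response.length`, with `UNKNOWN_LENGTH` when no header was sent. A sketch of just that decision, using a local sentinel as a stand-in for `twisted.web.iweb.UNKNOWN_LENGTH`:

```python
from typing import Optional, Union

UNKNOWN_LENGTH = object()  # stand-in for twisted.web.iweb.UNKNOWN_LENGTH


def exceeds_max_size_upfront(
    advertised_length: Union[int, object], max_size: Optional[int]
) -> bool:
    """Return True if the response should be rejected before streaming,
    based solely on the advertised Content-Length."""
    if max_size is not None and advertised_length is not UNKNOWN_LENGTH:
        return advertised_length > max_size
    # No cap configured, or no Content-Length header: stream the body and
    # count bytes in the protocol instead.
    return False
```

Responses that pass this check can still exceed the limit mid-stream, which is why the protocol's byte counter (and the new `abortConnection()` call) remains in place.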
194194
195195 @implementer(IAgentEndpointFactory)
196196 class MatrixHostnameEndpointFactory:
197 """Factory for MatrixHostnameEndpoint for parsing to an Agent.
198 """
197 """Factory for MatrixHostnameEndpoint for parsing to an Agent."""
199198
200199 def __init__(
201200 self,
260259 self._srv_resolver = srv_resolver
261260
262261 def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
263 """Implements IStreamClientEndpoint interface
264 """
262 """Implements IStreamClientEndpoint interface"""
265263
266264 return run_in_background(self._do_connect, protocol_factory)
267265
322320 if port or _is_ip_literal(host):
323321 return [Server(host, port or 8448)]
324322
323 logger.debug("Looking up SRV record for %s", host.decode(errors="replace"))
325324 server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
326325
327326 if server_list:
327 logger.debug(
328 "Got %s from SRV lookup for %s",
329 ", ".join(map(str, server_list)),
330 host.decode(errors="replace"),
331 )
328332 return server_list
329333
330334 # No SRV records, so we fallback to host and 8448
335 logger.debug("No SRV records for %s", host.decode(errors="replace"))
331336 return [Server(host, 8448)]
332337
333338
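
The new debug lines in the federation endpoint log each stage of server discovery. The lookup order itself is unchanged: an explicit port or IP literal wins, then `_matrix._tcp` SRV records, then a fallback to `host:8448`. Sketched synchronously below with a hypothetical resolver callable standing in for `SrvResolver.resolve_service`:

```python
import logging
from typing import Callable, List, NamedTuple, Optional

logger = logging.getLogger(__name__)


class Server(NamedTuple):
    host: bytes
    port: int


def resolve_matrix_servers(
    host: bytes,
    port: Optional[int],
    is_ip_literal: bool,
    resolve_srv: Callable[[bytes], List[Server]],
) -> List[Server]:
    """Sketch of the discovery order with the new debug logging."""
    if port or is_ip_literal:
        # An explicit port or an IP literal skips SRV resolution entirely.
        return [Server(host, port or 8448)]

    logger.debug("Looking up SRV record for %s", host.decode(errors="replace"))
    server_list = resolve_srv(b"_matrix._tcp." + host)
    if server_list:
        logger.debug(
            "Got %s from SRV lookup for %s",
            ", ".join(map(str, server_list)),
            host.decode(errors="replace"),
        )
        return server_list

    # No SRV records: fall back to the default federation port.
    logger.debug("No SRV records for %s", host.decode(errors="replace"))
    return [Server(host, 8448)]
```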
8080
8181
8282 class WellKnownResolver:
83 """Handles well-known lookups for matrix servers.
84 """
83 """Handles well-known lookups for matrix servers."""
8584
8685 def __init__(
8786 self,
253253 # Use a BlacklistingAgentWrapper to prevent circumventing the IP
254254 # blacklist via IP literals in server names
255255 self.agent = BlacklistingAgentWrapper(
256 self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
256 self.agent,
257 ip_blacklist=hs.config.federation_ip_range_blacklist,
257258 )
258259
259260 self.clock = hs.get_clock()
651652 backoff_on_404: bool = False,
652653 try_trailing_slash_on_400: bool = False,
653654 ) -> Union[JsonDict, list]:
654 """ Sends the specified json data using PUT
655 """Sends the specified json data using PUT
655656
656657 Args:
657658 destination: The remote server to send the HTTP request to.
739740 ignore_backoff: bool = False,
740741 args: Optional[QueryArgs] = None,
741742 ) -> Union[JsonDict, list]:
742 """ Sends the specified json data using POST
743 """Sends the specified json data using POST
743744
744745 Args:
745746 destination: The remote server to send the HTTP request to.
798799 _sec_timeout = self.default_timeout
799800
800801 body = await _handle_json_response(
801 self.reactor, _sec_timeout, request, response, start_ms,
802 self.reactor,
803 _sec_timeout,
804 request,
805 response,
806 start_ms,
802807 )
803808 return body
804809
812817 ignore_backoff: bool = False,
813818 try_trailing_slash_on_400: bool = False,
814819 ) -> Union[JsonDict, list]:
815 """ GETs some json from the given host homeserver and path
820 """GETs some json from the given host homeserver and path
816821
817822 Args:
818823 destination: The remote server to send the HTTP request to.
993998 except BodyExceededMaxSize:
994999 msg = "Requested file is too large > %r bytes" % (max_size,)
9951000 logger.warning(
996 "{%s} [%s] %s", request.txn_id, request.destination, msg,
1001 "{%s} [%s] %s",
1002 request.txn_id,
1003 request.destination,
1004 msg,
9971005 )
9981006 raise SynapseError(502, msg, Codes.TOO_LARGE)
9991007 except Exception as e:
212212 self.update_metrics()
213213
214214 def update_metrics(self):
215 """Updates the in flight metrics with values from this request.
216 """
215 """Updates the in flight metrics with values from this request."""
217216 new_stats = self.start_context.get_resource_usage()
218217
219218 diff = new_stats - self._request_stats
7575
7676
7777 def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
78 """Sends a JSON error response to clients.
79 """
78 """Sends a JSON error response to clients."""
8079
8180 if f.check(SynapseError):
8281 error_code = f.value.code
105104 pass
106105 else:
107106 respond_with_json(
108 request, error_code, error_dict, send_cors=True,
107 request,
108 error_code,
109 error_dict,
110 send_cors=True,
109111 )
110112
111113
112114 def return_html_error(
113 f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
115 f: failure.Failure,
116 request: Request,
117 error_template: Union[str, jinja2.Template],
114118 ) -> None:
115119 """Sends an HTML error page corresponding to the given failure.
116120
188192
189193
190194 class HttpServer(Protocol):
191 """ Interface for registering callbacks on a HTTP server
192 """
195 """Interface for registering callbacks on a HTTP server"""
193196
194197 def register_paths(
195198 self,
198201 callback: ServletCallback,
199202 servlet_classname: str,
200203 ) -> None:
201 """ Register a callback that gets fired if we receive a http request
204 """Register a callback that gets fired if we receive a http request
202205 with the given method for a path that matches the given regex.
203206
204207 If the regex contains groups these gets passed to the callback via
234237 self._extract_context = extract_context
235238
236239 def render(self, request):
237 """ This gets called by twisted every time someone sends us a request.
238 """
240 """This gets called by twisted every time someone sends us a request."""
239241 defer.ensureDeferred(self._async_render_wrapper(request))
240242 return NOT_DONE_YET
241243
286288
287289 @abc.abstractmethod
288290 def _send_response(
289 self, request: SynapseRequest, code: int, response_object: Any,
291 self,
292 request: SynapseRequest,
293 code: int,
294 response_object: Any,
290295 ) -> None:
291296 raise NotImplementedError()
292297
293298 @abc.abstractmethod
294299 def _send_error_response(
295 self, f: failure.Failure, request: SynapseRequest,
300 self,
301 f: failure.Failure,
302 request: SynapseRequest,
296303 ) -> None:
297304 raise NotImplementedError()
298305
307314 self.canonical_json = canonical_json
308315
309316 def _send_response(
310 self, request: Request, code: int, response_object: Any,
317 self,
318 request: Request,
319 code: int,
320 response_object: Any,
311321 ):
312 """Implements _AsyncResource._send_response
313 """
322 """Implements _AsyncResource._send_response"""
314323 # TODO: Only enable CORS for the requests that need it.
315324 respond_with_json(
316325 request,
321330 )
322331
323332 def _send_error_response(
324 self, f: failure.Failure, request: SynapseRequest,
333 self,
334 f: failure.Failure,
335 request: SynapseRequest,
325336 ) -> None:
326 """Implements _AsyncResource._send_error_response
327 """
337 """Implements _AsyncResource._send_error_response"""
328338 return_json_error(f, request)
329339
330340
331341 class JsonResource(DirectServeJsonResource):
332 """ This implements the HttpServer interface and provides JSON support for
342 """This implements the HttpServer interface and provides JSON support for
333343 Resources.
334344
335345 Register callbacks via register_paths()
442452 ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
443453
444454 def _send_response(
445 self, request: SynapseRequest, code: int, response_object: Any,
455 self,
456 request: SynapseRequest,
457 code: int,
458 response_object: Any,
446459 ):
447 """Implements _AsyncResource._send_response
448 """
460 """Implements _AsyncResource._send_response"""
449461 # We expect to get bytes for us to write
450462 assert isinstance(response_object, bytes)
451463 html_bytes = response_object
453465 respond_with_html_bytes(request, 200, html_bytes)
454466
455467 def _send_error_response(
456 self, f: failure.Failure, request: SynapseRequest,
468 self,
469 f: failure.Failure,
470 request: SynapseRequest,
457471 ) -> None:
458 """Implements _AsyncResource._send_error_response
459 """
472 """Implements _AsyncResource._send_error_response"""
460473 return_html_error(f, request, self.ERROR_TEMPLATE)
461474
462475
533546 min_chunk_size = 1024
534547
535548 def __init__(
536 self, request: Request, iterator: Iterator[bytes],
549 self,
550 request: Request,
551 iterator: Iterator[bytes],
537552 ):
538553 self._request = request
539554 self._iterator = iterator
653668
654669
655670 def respond_with_json_bytes(
656 request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
671 request: Request,
672 code: int,
673 json_bytes: bytes,
674 send_cors: bool = False,
657675 ):
658676 """Sends encoded JSON in response to the given request.
659677
768786
769787
770788 def finish_request(request: Request):
771 """ Finish writing the response to the request.
789 """Finish writing the response to the request.
772790
773791 Twisted throws a RuntimeException if the connection closed before the
774792 response was written but doesn't provide a convenient or reliable way to
257257
258258 class RestServlet:
259259
260 """ A Synapse REST Servlet.
260 """A Synapse REST Servlet.
261261
262262 An implementing class can either provide its own custom 'register' method,
263263 or use the automatic pattern handling provided by the base class.
248248 )
249249
250250 def _finished_processing(self):
251 """Log the completion of this request and update the metrics
252 """
251 """Log the completion of this request and update the metrics"""
253252 assert self.logcontext is not None
254253 usage = self.logcontext.get_resource_usage()
255254
275274 # authenticated (e.g. and admin is puppetting a user) then we log both.
276275 if self.requester.user.to_string() != authenticated_entity:
277276 authenticated_entity = "{},{}".format(
278 authenticated_entity, self.requester.user.to_string(),
277 authenticated_entity,
278 self.requester.user.to_string(),
279279 )
280280 elif self.requester is not None:
281281 # This shouldn't happen, but we log it so we don't lose information
321321 logger.warning("Failed to stop metrics: %r", e)
322322
323323 def _should_log_request(self) -> bool:
324 """Whether we should log at INFO that we processed the request.
325 """
324 """Whether we should log at INFO that we processed the request."""
326325 if self.path == b"/health":
327326 return False
328327
173173
174174 # Make a new producer and start it.
175175 self._producer = LogProducer(
176 buffer=self._buffer, transport=result.transport, format=self.format,
176 buffer=self._buffer,
177 transport=result.transport,
178 format=self.format,
177179 )
178180 result.transport.registerProducer(self._producer, True)
179181 self._producer.resumeProducing()
5959 )
6060
6161 # Either use the default formatter or the tersejson one.
62 if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
62 if logging_type in (
63 DrainType.CONSOLE_JSON,
64 DrainType.FILE_JSON,
65 ):
6366 formatter = "json" # type: Optional[str]
6467 elif logging_type in (
6568 DrainType.CONSOLE_JSON_TERSE,
130133 )
131134
132135
133 def setup_structured_logging(log_config: dict,) -> dict:
136 def setup_structured_logging(
137 log_config: dict,
138 ) -> dict:
134139 """
135140 Convert a legacy structured logging configuration (from Synapse < v1.23.0)
136141 to one compatible with the new standard library handlers.
337337 if self.previous_context != old_context:
338338 logcontext_error(
339339 "Expected previous context %r, found %r"
340 % (self.previous_context, old_context,)
340 % (
341 self.previous_context,
342 old_context,
343 )
341344 )
342345 return self
343346
561564 class PreserveLoggingContext:
562565 """Context manager which replaces the logging context
563566
564 The previous logging context is restored on exit."""
567 The previous logging context is restored on exit."""
565568
566569 __slots__ = ["_old_context", "_new_context"]
567570
584587 else:
585588 logcontext_error(
586589 "Expected logging context %s but found %s"
587 % (self._new_context, context,)
590 % (
591 self._new_context,
592 context,
593 )
588594 )
589595
590596
237237
238238 @attr.s(slots=True, frozen=True)
239239 class _WrappedRustReporter:
240 """Wrap the reporter to ensure `report_span` never throws.
241 """
240 """Wrap the reporter to ensure `report_span` never throws."""
242241
243242 _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
244243
325324
326325
327326 def init_tracer(hs: "HomeServer"):
328 """Set the whitelists and initialise the JaegerClient tracer
329 """
327 """Set the whitelists and initialise the JaegerClient tracer"""
330328 global opentracing
331329 if not hs.config.opentracer_enabled:
332330 # We don't have a tracer
383381
384382 Args:
385383 destination (str)
386 """
384 """
387385
388386 if _homeserver_whitelist:
389387 return _homeserver_whitelist.match(destination)
4242
4343
4444 def log_function(f):
45 """ Function decorator that logs every call to that function.
46 """
45 """Function decorator that logs every call to that function."""
4746 func_name = f.__name__
4847
4948 @wraps(f)
154154 self._registrations.setdefault(key, set()).add(callback)
155155
156156 def unregister(self, key, callback):
157 """Registers that we've exited a block with labels `key`.
158 """
157 """Registers that we've exited a block with labels `key`."""
159158
160159 with self._lock:
161160 self._registrations.setdefault(key, set()).discard(callback)
401400 # Total time spent in GC: 0.073 # s.total_gc_time
402401
403402 pypy_gc_time = CounterMetricFamily(
404 "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[],
403 "pypy_gc_time_seconds_total",
404 "Total time spent in PyPy GC",
405 labels=[],
405406 )
406407 pypy_gc_time.add_metric([], s.total_gc_time / 1000)
407408 yield pypy_gc_time
215215 @classmethod
216216 def factory(cls, registry):
217217 """Returns a dynamic MetricsHandler class tied
218 to the passed registry.
218 to the passed registry.
219219 """
220220 # This implementation relies on MetricsHandler.registry
221221 # (defined above and defaulted to REGISTRY).
207207 return await maybe_awaitable(func(*args, **kwargs))
208208 except Exception:
209209 logger.exception(
210 "Background process '%s' threw an exception", desc,
210 "Background process '%s' threw an exception",
211 desc,
211212 )
212213 finally:
213214 _background_process_in_flight_count.labels(desc).dec()
248249 self._proc = _BackgroundProcess(name, self)
249250
250251 def start(self, rusage: "Optional[resource._RUsage]"):
251 """Log context has started running (again).
252 """
252 """Log context has started running (again)."""
253253
254254 super().start(rusage)
255255
260260 _background_processes_active_since_last_scrape.add(self._proc)
261261
262262 def __exit__(self, type, value, traceback) -> None:
263 """Log context has finished.
264 """
263 """Log context has finished."""
265264
266265 super().__exit__(type, value, traceback)
267266
274274 redirect them directly if whitelisted).
275275 """
276276 self._auth_handler._complete_sso_login(
277 registered_user_id, request, client_redirect_url,
277 registered_user_id,
278 request,
279 client_redirect_url,
278280 )
279281
280282 async def complete_sso_login_async(
351353 event,
352354 _,
353355 ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
354 requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
356 requester,
357 event_dict,
358 ratelimit=False,
359 ignore_shadow_ban=True,
355360 )
356361
357362 return event
7474
7575
7676 class _NotificationListener:
77 """ This represents a single client connection to the events stream.
77 """This represents a single client connection to the events stream.
7878 The events stream handler will have yielded to the deferred, so to
7979 notify the handler it is sufficient to resolve the deferred.
8080 """
118118 self.notify_deferred = ObservableDeferred(defer.Deferred())
119119
120120 def notify(
121 self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
121 self,
122 stream_key: str,
123 stream_id: Union[int, RoomStreamToken],
124 time_now_ms: int,
122125 ):
123126 """Notify any listeners for this user of a new event from an
124127 event source.
139142 noify_deferred.callback(self.current_token)
140143
141144 def remove(self, notifier: "Notifier"):
142 """ Remove this listener from all the indexes in the Notifier
145 """Remove this listener from all the indexes in the Notifier
143146 it knows about.
144147 """
145148
185188
186189
187190 class Notifier:
188 """ This class is responsible for notifying any listeners when there are
191 """This class is responsible for notifying any listeners when there are
189192 new events available for it.
190193
191194 Primarily used from the /events stream.
264267 max_room_stream_token: RoomStreamToken,
265268 extra_users: Collection[UserID] = [],
266269 ):
267 """Unwraps event and calls `on_new_room_event_args`.
268 """
270 """Unwraps event and calls `on_new_room_event_args`."""
269271 self.on_new_room_event_args(
270272 event_pos=event_pos,
271273 room_id=event.room_id,
340342
341343 if users or rooms:
342344 self.on_new_event(
343 "room_key", max_room_stream_token, users=users, rooms=rooms,
345 "room_key",
346 max_room_stream_token,
347 users=users,
348 rooms=rooms,
344349 )
345350 self._on_updated_room_token(max_room_stream_token)
346351
391396 users: Collection[Union[str, UserID]] = [],
392397 rooms: Collection[str] = [],
393398 ):
394 """ Used to inform listeners that something has happened event wise.
399 """Used to inform listeners that something has happened event wise.
395400
396401 Will wake up all listeners for the given users and rooms.
397402 """
417422
418423 # Notify appservices
419424 self._notify_app_services_ephemeral(
420 stream_key, new_token, users,
425 stream_key,
426 new_token,
427 users,
421428 )
422429
423430 def on_new_replication_data(self) -> None:
501508 is_guest: bool = False,
502509 explicit_room_id: str = None,
503510 ) -> EventStreamResult:
504 """ For the given user and rooms, return any new events for them. If
511 """For the given user and rooms, return any new events for them. If
505512 there are no new events wait for up to `timeout` milliseconds for any
506513 new events to happen before returning.
507514
650657 cb()
651658
652659 def notify_remote_server_up(self, server: str):
653 """Notify any replication that a remote server has come back up
654 """
660 """Notify any replication that a remote server has come back up"""
655661 # We call federation_sender directly rather than registering as a
656662 # callback as a) we already have a reference to it and b) it introduces
657663 # circular dependencies.
143143
144144 @lru_cache()
145145 def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
146 """Get the current RulesForRoom object for the given room id
147 """
146 """Get the current RulesForRoom object for the given room id"""
148147 # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
149148 # before any lookup methods get called on it as otherwise there may be
150149 # a race if invalidate_all gets called (which assumes its in the cache)
251250 # notified for this event. (This will then get handled when we persist
252251 # the event)
253252 await self.store.add_push_actions_to_staging(
254 event.event_id, actions_by_user, count_as_unread,
253 event.event_id,
254 actions_by_user,
255 count_as_unread,
255256 )
256257
257258
523524 class _Invalidation:
524525 # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
525526 # which means that it is stored on the bulk_get_push_rules cache entry. In order
526 # to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
527 # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
527528 # we need to ensure that two _Invalidation objects are "equal" if they refer to the
528529 # same `cache` and `room_id`.
529530 #
115115 self._is_processing = True
116116
117117 def _resume_processing(self) -> None:
118 """Used by tests to resume processing of events after pausing.
119 """
118 """Used by tests to resume processing of events after pausing."""
120119 assert self._is_processing
121120 self._is_processing = False
122121 self._start_processing()
156155 being run.
157156 """
158157 start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
159 unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
160 self.user_id, start, self.max_stream_ordering
158 unprocessed = (
159 await self.store.get_unread_push_actions_for_user_in_range_for_email(
160 self.user_id, start, self.max_stream_ordering
161 )
161162 )
162163
163164 soonest_due_at = None # type: Optional[int]
221222 self, last_stream_ordering: int
222223 ) -> None:
223224 self.last_stream_ordering = last_stream_ordering
224 pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
225 self.app_id,
226 self.email,
227 self.user_id,
228 last_stream_ordering,
229 self.clock.time_msec(),
225 pusher_still_exists = (
226 await self.store.update_pusher_last_stream_ordering_and_success(
227 self.app_id,
228 self.email,
229 self.user_id,
230 last_stream_ordering,
231 self.clock.time_msec(),
232 )
230233 )
231234 if not pusher_still_exists:
232235 # The pusher has been deleted while we were processing, so
297300 current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
298301 )
299302 self.throttle_params[room_id] = ThrottleParams(
300 self.clock.time_msec(), new_throttle_ms,
303 self.clock.time_msec(),
304 new_throttle_ms,
301305 )
302306 assert self.pusher_id is not None
303307 await self.store.set_throttle_params(
175175 Never call this directly: use _process which will only allow this to
176176 run once per pusher.
177177 """
178 unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
179 self.user_id, self.last_stream_ordering, self.max_stream_ordering
178 unprocessed = (
179 await self.store.get_unread_push_actions_for_user_in_range_for_http(
180 self.user_id, self.last_stream_ordering, self.max_stream_ordering
181 )
180182 )
181183
182184 logger.info(
203205 http_push_processed_counter.inc()
204206 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
205207 self.last_stream_ordering = push_action["stream_ordering"]
206 pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
207 self.app_id,
208 self.pushkey,
209 self.user_id,
210 self.last_stream_ordering,
211 self.clock.time_msec(),
208 pusher_still_exists = (
209 await self.store.update_pusher_last_stream_ordering_and_success(
210 self.app_id,
211 self.pushkey,
212 self.user_id,
213 self.last_stream_ordering,
214 self.clock.time_msec(),
215 )
212216 )
213217 if not pusher_still_exists:
214218 # The pusher has been deleted while we were processing, so
289293 # for sanity, we only remove the pushkey if it
290294 # was the one we actually sent...
291295 logger.warning(
292 ("Ignoring rejected pushkey %s because we didn't send it"), pk,
296 ("Ignoring rejected pushkey %s because we didn't send it"),
297 pk,
293298 )
294299 else:
295300 logger.info("Pushkey %s was rejected: removing", pk)
3333 descriptor_from_member_events,
3434 name_from_member_event,
3535 )
36 from synapse.storage.state import StateFilter
3637 from synapse.types import StateMap, UserID
3738 from synapse.util.async_helpers import concurrently_execute
3839 from synapse.visibility import filter_events_for_client
109110
110111 self.sendmail = self.hs.get_sendmail()
111112 self.store = self.hs.get_datastore()
113 self.state_store = self.hs.get_storage().state
112114 self.macaroon_gen = self.hs.get_macaroon_generator()
113115 self.state_handler = self.hs.get_state_handler()
114116 self.storage = hs.get_storage()
216218 push_actions: Iterable[Dict[str, Any]],
217219 reason: Dict[str, Any],
218220 ) -> None:
219 """Send email regarding a user's room notifications"""
221 """
222 Send email regarding a user's room notifications
223
224 Params:
225 app_id: The application receiving the notification.
226 user_id: The user receiving the notification.
227 email_address: The email address receiving the notification.
228 push_actions: All outstanding notifications.
229 reason: The notification that was ready and is the cause of an email
230 being sent.
231 """
220232 rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
221233
222234 notif_events = await self.store.get_events(
240252 except StoreError:
241253 user_display_name = user_id
242254
243 async def _fetch_room_state(room_id):
255 async def _fetch_room_state(room_id: str) -> None:
244256 room_state = await self.store.get_current_state_ids(room_id)
245257 state_by_room[room_id] = room_state
246258
254266 rooms = []
255267
256268 for r in rooms_in_order:
257 roomvars = await self.get_room_vars(
269 roomvars = await self._get_room_vars(
258270 r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
259271 )
260272 rooms.append(roomvars)
270282 # Only one room has new stuff
271283 room_id = list(notifs_by_room.keys())[0]
272284
273 summary_text = await self.make_summary_text_single_room(
285 summary_text = await self._make_summary_text_single_room(
274286 room_id,
275287 notifs_by_room[room_id],
276288 state_by_room[room_id],
278290 user_id,
279291 )
280292 else:
281 summary_text = await self.make_summary_text(
293 summary_text = await self._make_summary_text(
282294 notifs_by_room, state_by_room, notif_events, reason
283295 )
284296
285297 template_vars = {
286298 "user_display_name": user_display_name,
287 "unsubscribe_link": self.make_unsubscribe_link(
299 "unsubscribe_link": self._make_unsubscribe_link(
288300 user_id, app_id, email_address
289301 ),
290302 "summary_text": summary_text,
348360 )
349361 )
350362
351 async def get_room_vars(
363 async def _get_room_vars(
352364 self,
353365 room_id: str,
354366 user_id: str,
356368 notif_events: Dict[str, EventBase],
357369 room_state_ids: StateMap[str],
358370 ) -> Dict[str, Any]:
371 """
372 Generate the variables for notifications on a per-room basis.
373
374 Args:
375 room_id: The room ID
376 user_id: The user receiving the notification.
377 notifs: The outstanding push actions for this room.
378 notif_events: The events related to the above notifications.
379 room_state_ids: The event IDs of the current room state.
380
381 Returns:
382 A dictionary to be added to the template context.
383 """
384
359385 # Check if one of the notifs is an invite event for the user.
360386 is_invite = False
361387 for n in notifs:
372398 "hash": string_ordinal_total(room_id), # See sender avatar hash
373399 "notifs": [],
374400 "invite": is_invite,
375 "link": self.make_room_link(room_id),
401 "link": self._make_room_link(room_id),
376402 } # type: Dict[str, Any]
377403
378404 if not is_invite:
379405 for n in notifs:
380 notifvars = await self.get_notif_vars(
406 notifvars = await self._get_notif_vars(
381407 n, user_id, notif_events[n["event_id"]], room_state_ids
382408 )
383409
404430
405431 return room_vars
406432
407 async def get_notif_vars(
433 async def _get_notif_vars(
408434 self,
409435 notif: Dict[str, Any],
410436 user_id: str,
411437 notif_event: EventBase,
412438 room_state_ids: StateMap[str],
413439 ) -> Dict[str, Any]:
440 """
441 Generate the variables for a single notification.
442
443 Args:
444 notif: The outstanding notification for this room.
445 user_id: The user receiving the notification.
446 notif_event: The event related to the above notification.
447 room_state_ids: The event IDs of the current room state.
448
449 Returns:
450 A dictionary to be added to the template context.
451 """
452
414453 results = await self.store.get_events_around(
415454 notif["room_id"],
416455 notif["event_id"],
419458 )
420459
421460 ret = {
422 "link": self.make_notif_link(notif),
461 "link": self._make_notif_link(notif),
423462 "ts": notif["received_ts"],
424463 "messages": [],
425464 }
430469 the_events.append(notif_event)
431470
432471 for event in the_events:
433 messagevars = await self.get_message_vars(notif, event, room_state_ids)
472 messagevars = await self._get_message_vars(notif, event, room_state_ids)
434473 if messagevars is not None:
435474 ret["messages"].append(messagevars)
436475
437476 return ret
438477
439 async def get_message_vars(
478 async def _get_message_vars(
440479 self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
441480 ) -> Optional[Dict[str, Any]]:
481 """
482 Generate the variables for a single event, if possible.
483
484 Args:
485 notif: The outstanding notification for this room.
486 event: The event under consideration.
487 room_state_ids: The event IDs of the current room state.
488
489 Returns:
490 A dictionary to be added to the template context, or None if the
491 event cannot be processed.
492 """
442493 if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
443494 return None
444495
445 sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
446 sender_state_event = await self.store.get_event(sender_state_event_id)
447 sender_name = name_from_member_event(sender_state_event)
448 sender_avatar_url = sender_state_event.content.get("avatar_url")
496 # Get the sender's name and avatar from the room state.
497 type_state_key = ("m.room.member", event.sender)
498 sender_state_event_id = room_state_ids.get(type_state_key)
499 if sender_state_event_id:
500 sender_state_event = await self.store.get_event(
501 sender_state_event_id
502 ) # type: Optional[EventBase]
503 else:
504 # Attempt to check the historical state for the room.
505 historical_state = await self.state_store.get_state_for_event(
506 event.event_id, StateFilter.from_types((type_state_key,))
507 )
508 sender_state_event = historical_state.get(type_state_key)
509
510 if sender_state_event:
511 sender_name = name_from_member_event(sender_state_event)
512 sender_avatar_url = sender_state_event.content.get("avatar_url")
513 else:
514 # No state could be found, fallback to the MXID.
515 sender_name = event.sender
516 sender_avatar_url = None
449517
450518 # 'hash' for deterministically picking default images: use
451519 # sender_hash % the number of default images to choose from
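The hunk above changes `_get_message_vars` to prefer the sender's membership event from the current room state and to fall back to a historical-state lookup (via `StateFilter.from_types`) when the sender has since left, finally degrading to the bare MXID. A minimal, self-contained sketch of that fallback pattern is below; the helper names and dict shapes are stand-ins for illustration, not Synapse's real storage API:

```python
# Sketch of the lookup order used in _get_message_vars: current room
# state first, then a narrower historical-state query for just the one
# ("m.room.member", sender) pair, then None (caller uses the MXID).
def resolve_sender_state(sender, current_state_ids, fetch_historical_state):
    """Return a stand-in for the sender's m.room.member event, or None."""
    key = ("m.room.member", sender)
    event_id = current_state_ids.get(key)
    if event_id is not None:
        # Sender is still in the current state map.
        return {"event_id": event_id, "source": "current"}
    # Fallback: ask only for this (type, state_key) pair, mirroring
    # StateFilter.from_types((key,)) in the real code.
    return fetch_historical_state(key)  # may be None

current = {("m.room.member", "@alice:example.org"): "$member_evt"}

def fake_historical(key):
    # Pretend only Bob has a historical membership event.
    if key[1] == "@bob:example.org":
        return {"event_id": "$old_evt", "source": "historical"}
    return None

assert resolve_sender_state("@alice:example.org", current, fake_historical)["source"] == "current"
assert resolve_sender_state("@bob:example.org", current, fake_historical)["source"] == "historical"
assert resolve_sender_state("@carol:example.org", current, fake_historical) is None
```

Narrowing the historical query with a state filter keeps the fallback cheap: only one state pair is resolved rather than the room's full state at that event.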
470538 ret["msgtype"] = msgtype
471539
472540 if msgtype == "m.text":
473 self.add_text_message_vars(ret, event)
541 self._add_text_message_vars(ret, event)
474542 elif msgtype == "m.image":
475 self.add_image_message_vars(ret, event)
543 self._add_image_message_vars(ret, event)
476544
477545 if "body" in event.content:
478546 ret["body_text_plain"] = event.content["body"]
479547
480548 return ret
481549
482 def add_text_message_vars(
550 def _add_text_message_vars(
483551 self, messagevars: Dict[str, Any], event: EventBase
484552 ) -> None:
553 """
554 Potentially add a sanitised message body to the message variables.
555
556 Args:
557 messagevars: The template context to be modified.
558 event: The event under consideration.
559 """
485560 msgformat = event.content.get("format")
486561
487562 messagevars["format"] = msgformat
494569 elif body:
495570 messagevars["body_text_html"] = safe_text(body)
496571
497 def add_image_message_vars(
572 def _add_image_message_vars(
498573 self, messagevars: Dict[str, Any], event: EventBase
499574 ) -> None:
500575 """
501576 Potentially add an image URL to the message variables.
577
578 Args:
579 messagevars: The template context to be modified.
580 event: The event under consideration.
502581 """
503582 if "url" in event.content:
504583 messagevars["image_url"] = event.content["url"]
505584
506 async def make_summary_text_single_room(
585 async def _make_summary_text_single_room(
507586 self,
508587 room_id: str,
509588 notifs: List[Dict[str, Any]],
516595
517596 Args:
518597 room_id: The ID of the room.
519 notifs: The notifications for this room.
598 notifs: The push actions for this room.
520599 room_state_ids: The state map for the room.
521600 notif_events: A map of event ID -> notification event.
522601 user_id: The user receiving the notification.
599678 "app": self.app_name,
600679 }
601680
602 return await self.make_summary_text_from_member_events(
681 return await self._make_summary_text_from_member_events(
603682 room_id, notifs, room_state_ids, notif_events
604683 )
605684
606 async def make_summary_text(
685 async def _make_summary_text(
607686 self,
608687 notifs_by_room: Dict[str, List[Dict[str, Any]]],
609688 room_state_ids: Dict[str, StateMap[str]],
614693 Make a summary text for the email when multiple rooms have notifications.
615694
616695 Args:
617 notifs_by_room: A map of room ID to the notifications for that room.
696 notifs_by_room: A map of room ID to the push actions for that room.
618697 room_state_ids: A map of room ID to the state map for that room.
619698 notif_events: A map of event ID -> notification event.
620699 reason: The reason this notification is being sent.
631710 }
632711
633712 room_id = reason["room_id"]
634 return await self.make_summary_text_from_member_events(
713 return await self._make_summary_text_from_member_events(
635714 room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events
636715 )
637716
638 async def make_summary_text_from_member_events(
717 async def _make_summary_text_from_member_events(
639718 self,
640719 room_id: str,
641720 notifs: List[Dict[str, Any]],
647726
648727 Args:
649728 room_id: The ID of the room.
650 notifs: The notifications for this room.
729 notifs: The push actions for this room.
651730 room_state_ids: The state map for the room.
652731 notif_events: A map of event ID -> notification event.
653732
656735 """
657736 # If the room doesn't have a name, say who the messages
658737 # are from explicitly to avoid, "messages in the Bob room"
659 sender_ids = {notif_events[n["event_id"]].sender for n in notifs}
660
661 member_events = await self.store.get_events(
662 [room_state_ids[("m.room.member", s)] for s in sender_ids]
663 )
738
739 # Find the latest event ID for each sender, note that the notifications
740 # are already in descending received_ts.
741 sender_ids = {}
742 for n in notifs:
743 sender = notif_events[n["event_id"]].sender
744 if sender not in sender_ids:
745 sender_ids[sender] = n["event_id"]
746
747 # Get the actual member events (in order to calculate a pretty name for
748 # the room).
749 member_event_ids = []
750 member_events = {}
751 for sender_id, event_id in sender_ids.items():
752 type_state_key = ("m.room.member", sender_id)
753 sender_state_event_id = room_state_ids.get(type_state_key)
754 if sender_state_event_id:
755 member_event_ids.append(sender_state_event_id)
756 else:
757 # Attempt to check the historical state for the room.
758 historical_state = await self.state_store.get_state_for_event(
759 event_id, StateFilter.from_types((type_state_key,))
760 )
761 sender_state_event = historical_state.get(type_state_key)
762 if sender_state_event:
763 member_events[event_id] = sender_state_event
764 member_events.update(await self.store.get_events(member_event_ids))
765
766 if not member_events:
767 # No member events were found! Maybe the room is empty?
768 # Fallback to the room ID (note that if there was a room name this
769 # would already have been used previously).
770 return self.email_subjects.messages_in_room % {
771 "room": room_id,
772 "app": self.app_name,
773 }
664774
665775 # There was a single sender.
666 if len(sender_ids) == 1:
776 if len(member_events) == 1:
667777 return self.email_subjects.messages_from_person % {
668778 "person": descriptor_from_member_events(member_events.values()),
669779 "app": self.app_name,
675785 "app": self.app_name,
676786 }
677787
678 def make_room_link(self, room_id: str) -> str:
788 def _make_room_link(self, room_id: str) -> str:
789 """
790 Generate a link to open a room in the web client.
791
792 Args:
793 room_id: The room ID to generate a link to.
794
795 Returns:
796 A link to open a room in the web client.
797 """
679798 if self.hs.config.email_riot_base_url:
680799 base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
681800 elif self.app_name == "Vector":
685804 base_url = "https://matrix.to/#"
686805 return "%s/%s" % (base_url, room_id)
687806
688 def make_notif_link(self, notif: Dict[str, str]) -> str:
807 def _make_notif_link(self, notif: Dict[str, str]) -> str:
808 """
809 Generate a link to open an event in the web client.
810
811 Args:
812 notif: The notification to generate a link for.
813
814 Returns:
815 A link to open the notification in the web client.
816 """
689817 if self.hs.config.email_riot_base_url:
690818 return "%s/#/room/%s/%s" % (
691819 self.hs.config.email_riot_base_url,
701829 else:
702830 return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
703831
704 def make_unsubscribe_link(
832 def _make_unsubscribe_link(
705833 self, user_id: str, app_id: str, email_address: str
706834 ) -> str:
835 """
836 Generate a link to unsubscribe from email notifications.
837
838 Args:
839 user_id: The user receiving the notification.
840 app_id: The application receiving the notification.
841 email_address: The email address receiving the notification.
842
843 Returns:
844 A link to unsubscribe from email notifications.
845 """
707846 params = {
708847 "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
709848 "app_id": app_id,
7777 self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
7878
7979 def start(self) -> None:
80 """Starts the pushers off in a background process.
81 """
80 """Starts the pushers off in a background process."""
8281 if not self._should_start_pushers:
8382 logger.info("Not starting pushers because they are disabled in the config")
8483 return
296295 return pusher
297296
298297 async def _start_pushers(self) -> None:
299 """Start all the pushers
300 """
298 """Start all the pushers"""
301299 pushers = await self.store.get_all_pushers()
302300
303301 # Stagger starting up the pushers so we don't completely drown the
334332 return None
335333 except Exception:
336334 logger.exception(
337 "Couldn't start pusher id %i: caught Exception", pusher_config.id,
335 "Couldn't start pusher id %i: caught Exception",
336 pusher_config.id,
338337 )
339338 return None
340339
8585
8686 CONDITIONAL_REQUIREMENTS = {
8787 "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
88 # we use execute_values with the fetch param, which arrived in psycopg 2.8.
89 "postgres": ["psycopg2>=2.8"],
88 "postgres": [
89 # we use execute_values with the fetch param, which arrived in psycopg 2.8.
90 "psycopg2>=2.8 ; platform_python_implementation != 'PyPy'",
91 "psycopg2cffi>=2.8 ; platform_python_implementation == 'PyPy'",
92 "psycopg2cffi-compat==1.1 ; platform_python_implementation == 'PyPy'",
93 ],
9094 # ACME support is required to provision TLS certificates from authorities
9195 # that use the protocol, such as Let's Encrypt.
9296 "acme": [
272272 pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
273273
274274 http_server.register_paths(
275 method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
275 method,
276 [pattern],
277 self._check_auth_and_handle,
278 self.__class__.__name__,
276279 )
277280
278281 def _check_auth_and_handle(self, request, **kwargs):
174174 return {}
175175
176176 async def _handle_request(self, request, user_id, room_id, tag):
177 max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
177 max_stream_id = await self.handler.remove_tag_from_room(
178 user_id,
179 room_id,
180 tag,
181 )
178182
179183 return 200, {"max_stream_id": max_stream_id}
180184
159159
160160 # hopefully we're now on the master, so this won't recurse!
161161 event_id, stream_id = await self.member_handler.remote_reject_invite(
162 invite_event_id, txn_id, requester, event_content,
162 invite_event_id,
163 txn_id,
164 requester,
165 event_content,
163166 )
164167
165168 return 200, {"event_id": event_id, "stream_id": stream_id}
2121
2222
2323 class ReplicationRegisterServlet(ReplicationEndpoint):
24 """Register a new user
25 """
24 """Register a new user"""
2625
2726 NAME = "register_user"
2827 PATH_ARGS = ("user_id",)
9695
9796
9897 class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
99 """Run any post registration actions
100 """
98 """Run any post registration actions"""
10199
102100 NAME = "post_register"
103101 PATH_ARGS = ("user_id",)
195195
196196
197197 class PingCommand(_SimpleCommand):
198 """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
199 """
 198 """Sent by either side as a keep-alive. The data is arbitrary (often a timestamp)."""
200199
201200 NAME = "PING"
202201
5959 return self._redis_connection is not None
6060
6161 async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
62 """Add the key/value to the named cache, with the expiry time given.
63 """
62 """Add the key/value to the named cache, with the expiry time given."""
6463
6564 if self._redis_connection is None:
6665 return
7574
7675 return await make_deferred_yieldable(
7776 self._redis_connection.set(
78 self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
77 self._get_redis_key(cache_name, key),
78 encoded_value,
79 pexpire=expiry_ms,
7980 )
8081 )
8182
8283 async def get(self, cache_name: str, key: str) -> Optional[Any]:
83 """Look up a key/value in the named cache.
84 """
84 """Look up a key/value in the named cache."""
8585
8686 if self._redis_connection is None:
8787 return None
302302 hs, outbound_redis_connection
303303 )
304304 hs.get_reactor().connectTCP(
305 hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
305 hs.config.redis.redis_host,
306 hs.config.redis.redis_port,
307 self._factory,
306308 )
307309 else:
308310 client_name = hs.get_instance_name()
312314 hs.get_reactor().connectTCP(host, port, self._factory)
313315
314316 def get_streams(self) -> Dict[str, Stream]:
315 """Get a map from stream name to all streams.
316 """
317 """Get a map from stream name to all streams."""
317318 return self._streams
318319
319320 def get_streams_to_replicate(self) -> List[Stream]:
320 """Get a list of streams that this instances replicates.
321 """
 321 """Get a list of streams that this instance replicates."""
322322 return self._streams_to_replicate
323323
324324 def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
339339 current_token = stream.current_token(self._instance_name)
340340 self.send_command(
341341 PositionCommand(
342 stream.NAME, self._instance_name, current_token, current_token,
342 stream.NAME,
343 self._instance_name,
344 current_token,
345 current_token,
343346 )
344347 )
345348
591594 self.send_command(cmd, ignore_conn=conn)
592595
593596 def new_connection(self, connection: AbstractConnection):
594 """Called when we have a new connection.
595 """
597 """Called when we have a new connection."""
596598 self._connections.append(connection)
597599
598600 # If we are connected to replication as a client (rather than a server)
619621 )
620622
621623 def lost_connection(self, connection: AbstractConnection):
622 """Called when a connection is closed/lost.
623 """
624 """Called when a connection is closed/lost."""
624625 # we no longer need _streams_by_connection for this connection.
625626 streams = self._streams_by_connection.pop(connection, None)
626627 if streams:
677678 def send_user_sync(
678679 self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
679680 ):
680 """Poke the master that a user has started/stopped syncing.
681 """
681 """Poke the master that a user has started/stopped syncing."""
682682 self.send_command(
683683 UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
684684 )
685685
686686 def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
687 """Poke the master to remove a pusher for a user
688 """
687 """Poke the master to remove a pusher for a user"""
689688 cmd = RemovePusherCommand(app_id, push_key, user_id)
690689 self.send_command(cmd)
691690
698697 device_id: str,
699698 last_seen: int,
700699 ):
701 """Tell the master that the user made a request.
702 """
700 """Tell the master that the user made a request."""
703701 cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
704702 self.send_command(cmd)
705703
221221 self.send_error("ping timeout")
222222
223223 def lineReceived(self, line: bytes):
224 """Called when we've received a line
225 """
224 """Called when we've received a line"""
226225 with PreserveLoggingContext(self._logging_context):
227226 self._parse_and_dispatch_line(line)
228227
298297 self.on_connection_closed()
299298
300299 def send_error(self, error_string, *args):
301 """Send an error to remote and close the connection.
302 """
300 """Send an error to remote and close the connection."""
303301 self.send_command(ErrorCommand(error_string % args))
304302 self.close()
305303
340338 self.last_sent_command = self.clock.time_msec()
341339
342340 def _queue_command(self, cmd):
343 """Queue the command until the connection is ready to write to again.
344 """
341 """Queue the command until the connection is ready to write to again."""
345342 logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
346343 self.pending_commands.append(cmd)
347344
354351 self.close()
355352
356353 def _send_pending_commands(self):
357 """Send any queued commandes
358 """
 354 """Send any queued commands"""
359355 pending = self.pending_commands
360356 self.pending_commands = []
361357 for cmd in pending:
379375 self.state = ConnectionStates.PAUSED
380376
381377 def resumeProducing(self):
382 """The remote has caught up after we started buffering!
383 """
378 """The remote has caught up after we started buffering!"""
384379 logger.info("[%s] Resume producing", self.id())
385380 self.state = ConnectionStates.ESTABLISHED
386381 self._send_pending_commands()
439434 return "%s-%s" % (self.name, self.conn_id)
440435
441436 def lineLengthExceeded(self, line):
442 """Called when we receive a line that is above the maximum line length
443 """
437 """Called when we receive a line that is above the maximum line length"""
444438 self.send_error("Line length exceeded")
445439
446440
494488 self.send_error("Wrong remote")
495489
496490 def replicate(self):
497 """Send the subscription request to the server
498 """
491 """Send the subscription request to the server"""
499492 logger.info("[%s] Subscribing to replication streams", self.id())
500493
501494 self.send_command(ReplicateCommand())
502495
503496
504497 class AbstractConnection(abc.ABC):
505 """An interface for replication connections.
506 """
498 """An interface for replication connections."""
507499
508500 @abc.abstractmethod
509501 def send_command(self, cmd: Command):
510 """Send the command down the connection
511 """
502 """Send the command down the connection"""
512503 pass
513504
514505
1414
1515 import logging
1616 from inspect import isawaitable
17 from typing import TYPE_CHECKING, Optional, Type, cast
18
17 from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
18
19 import attr
1920 import txredisapi
2021
2122 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
4142
4243 logger = logging.getLogger(__name__)
4344
45 T = TypeVar("T")
46 V = TypeVar("V")
47
48
49 @attr.s
50 class ConstantProperty(Generic[T, V]):
51 """A descriptor that returns the given constant, ignoring attempts to set
52 it.
53 """
54
55 constant = attr.ib() # type: V
56
57 def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V:
58 return self.constant
59
60 def __set__(self, obj: Optional[T], value: V):
61 pass
62
4463
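The `ConstantProperty` descriptor introduced above exists so that later assignments to the attribute (as txredisapi does internally on failure) are silently discarded. A standalone sketch of the pattern, using an assumed `RetryingFactory` stand-in class:

```python
# Standalone sketch (assumed names) of the constant-descriptor pattern:
# reads always yield the constant; writes are swallowed by __set__.
from typing import Generic, Optional, Type, TypeVar

T = TypeVar("T")
V = TypeVar("V")

class ConstantProperty(Generic[T, V]):
    def __init__(self, constant: V) -> None:
        self.constant = constant

    def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
        return self.constant

    def __set__(self, obj: T, value: V) -> None:
        # Defining __set__ makes this a *data* descriptor, so instance
        # attribute writes route here instead of the instance __dict__,
        # and are ignored.
        pass

class RetryingFactory:
    # Stays True even if library code assigns `self.continueTrying = False`.
    continueTrying = ConstantProperty(True)
```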
4564 class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
4665 """Connection to redis subscribed to replication stream.
103122 self.synapse_handler.send_positions_to_connection(self)
104123
105124 def messageReceived(self, pattern: str, channel: str, message: str):
106 """Received a message from redis.
107 """
125 """Received a message from redis."""
108126 with PreserveLoggingContext(self._logging_context):
109127 self._parse_and_dispatch_message(message)
110128
117135 cmd = parse_command_from_line(message)
118136 except Exception:
119137 logger.exception(
120 "Failed to parse replication line: %r", message,
138 "Failed to parse replication line: %r",
139 message,
121140 )
122141 return
123142
193212 """A subclass of RedisFactory that periodically sends pings to ensure that
194213 we detect dead connections.
195214 """
215
 216 # We want to *always* retry connecting: txredisapi stops retrying if there is
 217 # a failure during certain operations, e.g. during AUTH.
218 continueTrying = cast(bool, ConstantProperty(True))
196219
197220 def __init__(
198221 self,
242265 """
243266
244267 maxDelay = 5
245 continueTrying = True
246268 protocol = RedisSubscriber
247269
248270 def __init__(
3535
3636
3737 class ReplicationStreamProtocolFactory(Factory):
38 """Factory for new replication connections.
39 """
38 """Factory for new replication connections."""
4039
4140 def __init__(self, hs):
4241 self.command_handler = hs.get_tcp_replication()
180179 raise
181180
182181 logger.debug(
183 "Sending %d updates", len(updates),
182 "Sending %d updates",
183 len(updates),
184184 )
185185
186186 if updates:
182182 return [], upto_token, False
183183
184184 updates, upto_token, limited = await self.update_function(
185 instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
185 instance_name,
186 from_token,
187 upto_token,
188 _STREAM_UPDATE_TARGET_ROW_COUNT,
186189 )
187190 return updates, upto_token, limited
188191
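The `update_function` call above caps each batch at `_STREAM_UPDATE_TARGET_ROW_COUNT` rows and signals via `limited` whether more rows remain below the requested token. A hedged sketch of that pagination contract over an in-memory (token, row) list; names here are illustrative, not Synapse's actual implementation:

```python
# Illustrative sketch of a bounded stream update_function: return at most
# target_row_count rows in (from_token, upto_token], and shrink the claimed
# upto_token when the batch was truncated so callers can resume from it.
from typing import Any, List, Tuple

def get_updates(
    rows: List[Tuple[int, Any]],
    from_token: int,
    upto_token: int,
    target_row_count: int,
) -> Tuple[List[Tuple[int, Any]], int, bool]:
    window = [(tok, row) for tok, row in rows if from_token < tok <= upto_token]
    limited = len(window) > target_row_count
    window = window[:target_row_count]
    if limited:
        # Only claim the token range actually covered by this batch.
        upto_token = window[-1][0]
    return window, upto_token, limited
```

Callers loop, feeding the returned token back in as `from_token` until `limited` is False.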
338341
339342
340343 class PushRulesStream(Stream):
341 """A user has changed their push rules
342 """
344 """A user has changed their push rules"""
343345
344346 PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
345347
361363
362364
363365 class PushersStream(Stream):
364 """A user has added/changed/removed a pusher
365 """
366 """A user has added/changed/removed a pusher"""
366367
367368 PushersStreamRow = namedtuple(
368369 "PushersStreamRow",
415416
416417
417418 class PublicRoomsStream(Stream):
418 """The public rooms list changed
419 """
419 """The public rooms list changed"""
420420
421421 PublicRoomsStreamRow = namedtuple(
422422 "PublicRoomsStreamRow",
462462
463463
464464 class ToDeviceStream(Stream):
465 """New to_device messages for a client
466 """
465 """New to_device messages for a client"""
467466
468467 ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
469468
480479
481480
482481 class TagAccountDataStream(Stream):
483 """Someone added/removed a tag for a room
484 """
482 """Someone added/removed a tag for a room"""
485483
486484 TagAccountDataStreamRow = namedtuple(
487485 "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
500498
501499
502500 class AccountDataStream(Stream):
503 """Global or per room account data was changed
504 """
501 """Global or per room account data was changed"""
505502
506503 AccountDataStreamRow = namedtuple(
507504 "AccountDataStream",
588585
589586
590587 class UserSignatureStream(Stream):
591 """A user has signed their own device with their user-signing key
592 """
588 """A user has signed their own device with their user-signing key"""
593589
594590 UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
595591
112112
113113
114114 class EventsStream(Stream):
115 """We received a new event, or an event went from being an outlier to not
116 """
 115 """We received a new event, or an event went from being an outlier to not being one"""
117116
118117 NAME = "events"
119118
0 body {
0 body, input, select, textarea {
11 font-family: "Inter", "Helvetica", "Arial", sans-serif;
22 font-size: 14px;
33 color: #17191C;
44 }
55
6 header {
6 header, footer {
77 max-width: 480px;
88 width: 100%;
99 margin: 24px auto;
1010 text-align: center;
11 }
12
13 @media screen and (min-width: 800px) {
14 header {
15 margin-top: 90px;
16 }
17 }
18
19 header {
20 min-height: 60px;
1121 }
1222
1323 header p {
1727
1828 h1 {
1929 font-size: 24px;
30 }
31
32 a {
33 color: #418DED;
2034 }
2135
2236 .error_page h1 {
4660
4761 .primary-button {
4862 border: none;
63 -webkit-appearance: none;
64 -moz-appearance: none;
65 appearance: none;
4966 text-decoration: none;
5067 padding: 12px;
5168 color: white;
6279
6380 .profile {
6481 display: flex;
82 flex-direction: column;
83 align-items: center;
6584 justify-content: center;
66 margin: 24px 0;
85 margin: 24px;
86 padding: 13px;
87 border: 1px solid #E9ECF1;
88 border-radius: 4px;
89 }
90
91 .profile.with-avatar {
 92 margin-top: 42px; /* (36px / 2) + 24px */
6793 }
6894
6995 .profile .avatar {
7197 height: 36px;
7298 border-radius: 100%;
7399 display: block;
74 margin-right: 8px;
100 margin-top: -32px;
101 margin-bottom: 8px;
75102 }
76103
77104 .profile .display-name {
78105 font-weight: bold;
79106 margin-bottom: 4px;
107 font-size: 15px;
108 line-height: 18px;
80109 }
81110 .profile .user-id {
82111 color: #737D8C;
112 font-size: 12px;
113 line-height: 12px;
83114 }
84115
85 .profile .display-name, .profile .user-id {
86 line-height: 18px;
116 footer {
117 margin-top: 80px;
87118 }
119
120 footer svg {
121 display: block;
122 width: 46px;
123 margin: 0px auto 12px auto;
124 }
125
126 footer p {
127 color: #737D8C;
128 }
1919 administrator.
2020 </p>
2121 </header>
22 {% include "sso_footer.html" without context %}
2223 </body>
2324 </html>
00 <!DOCTYPE html>
11 <html lang="en">
22 <head>
3 <title>Synapse Login</title>
3 <title>Create your account</title>
44 <meta charset="utf-8">
55 <meta name="viewport" content="width=device-width, user-scalable=no">
6 <script type="text/javascript">
7 let wasKeyboard = false;
8 document.addEventListener("mousedown", function() { wasKeyboard = false; });
9 document.addEventListener("keydown", function() { wasKeyboard = true; });
10 document.addEventListener("focusin", function() {
11 if (wasKeyboard) {
12 document.body.classList.add("keyboard-focus");
13 } else {
14 document.body.classList.remove("keyboard-focus");
15 }
16 });
17 </script>
618 <style type="text/css">
719 {% include "sso.css" without context %}
20
21 body.keyboard-focus :focus, body.keyboard-focus .username_input:focus-within {
22 outline: 3px solid #17191C;
23 outline-offset: 4px;
24 }
825
926 .username_input {
1027 display: flex;
3249
3350 .username_input label {
3451 position: absolute;
35 top: -8px;
52 top: -5px;
3653 left: 14px;
37 font-size: 80%;
54 font-size: 10px;
55 line-height: 10px;
3856 background: white;
39 padding: 2px;
57 padding: 0 2px;
4058 }
4159
4260 .username_input input {
4462 display: block;
4563 min-width: 0;
4664 border: none;
65 }
66
67 /* only clear the outline if we know it will be shown on the parent div using :focus-within */
68 @supports selector(:focus-within) {
69 .username_input input {
70 outline: none !important;
71 }
4772 }
4873
4974 .username_input div {
6489 .idp-pick-details .idp-detail {
6590 border-top: 1px solid #E9ECF1;
6691 padding: 12px;
92 display: block;
6793 }
6894 .idp-pick-details .check-row {
6995 display: flex;
116142 </div>
117143 <output for="username_input" id="field-username-output"></output>
118144 <input type="submit" value="Continue" class="primary-button">
119 {% if user_attributes %}
145 {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
120146 <section class="idp-pick-details">
121147 <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2>
122148 {% if user_attributes.avatar_url %}
123 <div class="idp-detail idp-avatar">
149 <label class="idp-detail idp-avatar" for="idp-avatar">
124150 <div class="check-row">
125 <label for="idp-avatar" class="name">Avatar</label>
126 <label for="idp-avatar" class="use">Use</label>
151 <span class="name">Avatar</span>
152 <span class="use">Use</span>
127153 <input type="checkbox" name="use_avatar" id="idp-avatar" value="true" checked>
128154 </div>
129155 <img src="{{ user_attributes.avatar_url }}" class="avatar" />
130 </div>
156 </label>
131157 {% endif %}
132158 {% if user_attributes.display_name %}
133 <div class="idp-detail">
159 <label class="idp-detail" for="idp-displayname">
134160 <div class="check-row">
135 <label for="idp-displayname" class="name">Display name</label>
136 <label for="idp-displayname" class="use">Use</label>
161 <span class="name">Display name</span>
162 <span class="use">Use</span>
137163 <input type="checkbox" name="use_display_name" id="idp-displayname" value="true" checked>
138164 </div>
139165 <p class="idp-value">{{ user_attributes.display_name }}</p>
140 </div>
166 </label>
141167 {% endif %}
142168 {% for email in user_attributes.emails %}
143 <div class="idp-detail">
169 <label class="idp-detail" for="idp-email{{ loop.index }}">
144170 <div class="check-row">
145 <label for="idp-email{{ loop.index }}" class="name">E-mail</label>
146 <label for="idp-email{{ loop.index }}" class="use">Use</label>
171 <span class="name">E-mail</span>
172 <span class="use">Use</span>
147173 <input type="checkbox" name="use_email" id="idp-email{{ loop.index }}" value="{{ email }}" checked>
148174 </div>
149175 <p class="idp-value">{{ email }}</p>
150 </div>
176 </label>
151177 {% endfor %}
152178 </section>
153179 {% endif %}
154180 </form>
155181 </main>
182 {% include "sso_footer.html" without context %}
156183 <script type="text/javascript">
157184 {% include "sso_auth_account_details.js" without context %}
158185 </script>
2020 the Identity Provider as when you log into your account.
2121 </p>
2222 </header>
23 {% include "sso_footer.html" without context %}
2324 </body>
2425 </html>
11 <html lang="en">
22 <head>
33 <meta charset="UTF-8">
4 <title>Authentication</title>
4 <title>Confirm it's you</title>
55 <meta name="viewport" content="width=device-width, user-scalable=no">
66 <style type="text/css">
77 {% include "sso.css" without context %}
2323 Continue with {{ idp.idp_name }}
2424 </a>
2525 </main>
26 {% include "sso_footer.html" without context %}
2627 </body>
2728 </html>
2222 application.
2323 </p>
2424 </header>
25 {% include "sso_footer.html" without context %}
2526 </body>
2627 </html>
3737 <p>{{ error }}</p>
3838 </div>
3939 </header>
40 {% include "sso_footer.html" without context %}
4041
4142 <script type="text/javascript">
4243 // Error handling to support Auth0 errors that we might get through a GET request
0 <footer>
1 <svg role="img" aria-label="[Matrix logo]" viewBox="0 0 200 85" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
2 <g id="parent" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
3 <g id="child" transform="translate(-122.000000, -6.000000)" fill="#000000" fill-rule="nonzero">
4 <g id="matrix-logo" transform="translate(122.000000, 6.000000)">
5 <polygon id="left-bracket" points="2.24708861 1.93811009 2.24708861 82.7268844 8.10278481 82.7268844 8.10278481 84.6652459 0 84.6652459 0 0 8.10278481 0 8.10278481 1.93811009"></polygon>
6 <path d="M24.8073418,27.5493174 L24.8073418,31.6376991 L24.924557,31.6376991 C26.0227848,30.0814294 27.3455696,28.8730642 28.8951899,28.0163743 C30.4437975,27.1611927 32.2189873,26.7318422 34.218481,26.7318422 C36.1394937,26.7318422 37.8946835,27.102622 39.4825316,27.8416679 C41.0708861,28.5819706 42.276962,29.8856073 43.1005063,31.7548404 C44.0017722,30.431345 45.2270886,29.2629486 46.7767089,28.2506569 C48.3253165,27.2388679 50.158481,26.7318422 52.2764557,26.7318422 C53.8843038,26.7318422 55.3736709,26.9269101 56.7473418,27.3162917 C58.1189873,27.7056734 59.295443,28.3285835 60.2759494,29.185022 C61.255443,30.0422147 62.02,31.1615927 62.5701266,32.5426532 C63.1187342,33.9262275 63.3936709,35.5898349 63.3936709,37.5372459 L63.3936709,57.7443688 L55.0410127,57.7441174 L55.0410127,40.6319376 C55.0410127,39.6201486 55.0020253,38.6661761 54.9232911,37.7700202 C54.8440506,36.8751211 54.6293671,36.0968606 54.2764557,35.4339817 C53.9232911,34.772611 53.403038,34.2464807 52.7177215,33.8568477 C52.0313924,33.4689743 51.0997468,33.2731523 49.9235443,33.2731523 C48.7473418,33.2731523 47.7962025,33.4983853 47.0706329,33.944578 C46.344557,34.393033 45.7764557,34.9774826 45.3650633,35.6969211 C44.9534177,36.4181193 44.6787342,37.2353431 44.5417722,38.150855 C44.4037975,39.0653615 44.3356962,39.9904257 44.3356962,40.9247908 L44.3356962,57.7443688 L35.9835443,57.7443688 L35.9835443,40.8079009 C35.9835443,39.9124991 35.963038,39.0263982 35.9253165,38.150855 C35.8853165,37.2743064 35.7192405,36.4666349 35.424557,35.7263321 C35.1303797,34.9872862 34.64,34.393033 33.9539241,33.944578 C33.2675949,33.4983853 32.2579747,33.2731523 30.9248101,33.2731523 C30.5321519,33.2731523 30.0126582,33.3608826 29.3663291,33.5365945 C28.7192405,33.7118037 28.0913924,34.0433688 27.4840506,34.5292789 C26.875443,35.0164459 26.3564557,35.7172826 25.9250633,36.6315376 C25.4934177,37.5470495 25.2779747,38.7436 25.2779747,40.2229486 L25.2779747,57.7441174 L16.9260759,57.7443688 L16.9260759,27.5493174 
L24.8073418,27.5493174 Z" id="m"></path>
7 <path d="M68.7455696,31.9886202 C69.6075949,30.7033339 70.7060759,29.672189 72.0397468,28.8926716 C73.3724051,28.1141596 74.8716456,27.5596239 76.5387342,27.2283101 C78.2050633,26.8977505 79.8817722,26.7315908 81.5678481,26.7315908 C83.0974684,26.7315908 84.6458228,26.8391798 86.2144304,27.0525982 C87.7827848,27.2675248 89.2144304,27.6865688 90.5086076,28.3087248 C91.8025316,28.9313835 92.8610127,29.7983798 93.6848101,30.9074514 C94.5083544,32.0170257 94.92,33.4870734 94.92,35.3173431 L94.92,51.026844 C94.92,52.3913138 94.998481,53.6941963 95.1556962,54.9400165 C95.3113924,56.1865908 95.5863291,57.120956 95.9787342,57.7436147 L87.5091139,57.7436147 C87.3518987,57.276055 87.2240506,56.7996972 87.1265823,56.3125303 C87.0278481,55.8266202 86.9592405,55.3301523 86.9207595,54.8236294 C85.5873418,56.1865908 84.0182278,57.1405633 82.2156962,57.6857982 C80.4113924,58.2295248 78.5683544,58.503022 76.6860759,58.503022 C75.2346835,58.503022 73.8817722,58.3275615 72.6270886,57.9776459 C71.3718987,57.6269761 70.2744304,57.082244 69.3334177,56.3411872 C68.3921519,55.602644 67.656962,54.6680275 67.1275949,53.5390972 C66.5982278,52.410167 66.3331646,51.065556 66.3331646,49.5087835 C66.3331646,47.7961578 66.6367089,46.384178 67.2455696,45.2756092 C67.8529114,44.1652807 68.6367089,43.2799339 69.5987342,42.6173064 C70.5589873,41.9556844 71.6567089,41.4592165 72.8924051,41.1284055 C74.1273418,40.7978459 75.3721519,40.5356606 76.6270886,40.3398385 C77.8820253,40.1457761 79.116962,39.9896716 80.3329114,39.873033 C81.5483544,39.7558917 82.6270886,39.5804312 83.5681013,39.3469028 C84.5093671,39.1133743 85.2536709,38.7732624 85.8032911,38.3250587 C86.3513924,37.8773578 86.6063291,37.2252881 86.5678481,36.3680954 C86.5678481,35.4731963 86.4210127,34.7620532 86.1268354,34.2366771 C85.8329114,33.7113009 85.4405063,33.3018092 84.9506329,33.0099615 C84.4602532,32.7181138 83.8916456,32.5232972 83.2450633,32.4255119 C82.5977215,32.3294862 81.9010127,32.2797138 81.156962,32.2797138 
C79.5098734,32.2797138 78.2159494,32.6303835 77.2746835,33.3312202 C76.3339241,34.0320569 75.7837975,35.2007046 75.6275949,36.8354037 L67.275443,36.8354037 C67.3924051,34.8892495 67.8817722,33.2726495 68.7455696,31.9886202 Z M85.2440506,43.6984752 C84.7149367,43.873433 84.1460759,44.0189798 83.5387342,44.1361211 C82.9306329,44.253011 82.2936709,44.350545 81.6270886,44.4279688 C80.96,44.5066495 80.2934177,44.6034294 79.6273418,44.7203193 C78.9994937,44.8362037 78.3820253,44.9933138 77.7749367,45.1871248 C77.1663291,45.3829468 76.636962,45.6451321 76.1865823,45.9759431 C75.7349367,46.3070055 75.3724051,46.7263009 75.0979747,47.2313156 C74.8232911,47.7375872 74.6863291,48.380356 74.6863291,49.1588679 C74.6863291,49.8979138 74.8232911,50.5218294 75.0979747,51.026844 C75.3724051,51.5338697 75.7455696,51.9328037 76.2159494,52.2246514 C76.6863291,52.5164991 77.2349367,52.7213706 77.8632911,52.8375064 C78.4898734,52.9546477 79.136962,53.012967 79.8037975,53.012967 C81.4506329,53.012967 82.724557,52.740978 83.6273418,52.1952404 C84.5288608,51.6507596 85.1949367,50.9981872 85.6270886,50.2382771 C86.0579747,49.4793725 86.323038,48.7119211 86.4212658,47.9321523 C86.518481,47.1536404 86.5681013,46.5304789 86.5681013,46.063422 L86.5681013,42.9677248 C86.2146835,43.2799339 85.7736709,43.5230147 85.2440506,43.6984752 Z" id="a"></path>
8 <path d="M116.917975,27.5493174 L116.917975,33.0976917 L110.801266,33.0976917 L110.801266,48.0492936 C110.801266,49.4502128 111.036203,50.3850807 111.507089,50.8518862 C111.976962,51.3191945 112.918734,51.5527229 114.33038,51.5527229 C114.801013,51.5527229 115.251392,51.5336183 115.683038,51.4944037 C116.114177,51.4561945 116.526076,51.3968697 116.917975,51.3194459 L116.917975,57.7438661 C116.212152,57.860756 115.427595,57.9381798 114.565316,57.9778972 C113.702785,58.0153523 112.859747,58.0357138 112.036203,58.0357138 C110.742278,58.0357138 109.516456,57.9477321 108.36,57.7722716 C107.202785,57.5975651 106.183544,57.2577046 105.301519,56.7509303 C104.418987,56.2454128 103.722785,55.5242147 103.213418,54.5898495 C102.703038,53.6562385 102.448608,52.4292716 102.448608,50.9099541 L102.448608,33.0976917 L97.3903797,33.0976917 L97.3903797,27.5493174 L102.448608,27.5493174 L102.448608,18.4967596 L110.801013,18.4967596 L110.801013,27.5493174 L116.917975,27.5493174 Z" id="t"></path>
9 <path d="M128.857975,27.5493174 L128.857975,33.1565138 L128.975696,33.1565138 C129.367089,32.2213945 129.896203,31.3559064 130.563544,30.557033 C131.23038,29.7596679 131.99443,29.0776844 132.857215,28.5130936 C133.719241,27.9495083 134.641266,27.5113596 135.622532,27.1988991 C136.601772,26.8879468 137.622025,26.7315908 138.681013,26.7315908 C139.229873,26.7315908 139.836962,26.8296275 140.504304,27.0239413 L140.504304,34.7336477 C140.111646,34.6552183 139.641013,34.586844 139.092658,34.5290275 C138.543291,34.4704569 138.014177,34.4410459 137.504304,34.4410459 C135.974937,34.4410459 134.681013,34.6949358 133.622785,35.2004532 C132.564051,35.7067248 131.711392,36.397255 131.064051,37.2735523 C130.417215,38.1501009 129.955443,39.1714422 129.681266,40.3398385 C129.407089,41.5074807 129.269873,42.7736624 129.269873,44.1361211 L129.269873,57.7438661 L120.917722,57.7438661 L120.917722,27.5493174 L128.857975,27.5493174 Z" id="r"></path>
10 <path d="M144.033165,22.8767376 L144.033165,16.0435798 L152.386076,16.0435798 L152.386076,22.8767376 L144.033165,22.8767376 Z M152.386076,27.5493174 L152.386076,57.7438661 L144.033165,57.7438661 L144.033165,27.5493174 L152.386076,27.5493174 Z" id="i"></path>
11 <polygon id="x" points="156.738228 27.5493174 166.266582 27.5493174 171.619494 35.4337303 176.913418 27.5493174 186.147848 27.5493174 176.148861 41.6831927 187.383544 57.7441174 177.85443 57.7441174 171.501772 48.2245028 165.148861 57.7441174 155.797468 57.7441174 166.737468 41.8589046"></polygon>
12 <polygon id="right-bracket" points="197.580759 82.7268844 197.580759 1.93811009 191.725063 1.93811009 191.725063 0 199.828354 0 199.828354 84.6652459 191.725063 84.6652459 191.725063 82.7268844"></polygon>
13 </g>
14 </g>
15 </g>
16 </svg>
17 <p>An open network for secure, decentralized communication.<br>© 2021 The Matrix.org Foundation C.I.C.</p>
18 </footer>
11 <html lang="en">
22 <head>
33 <meta charset="UTF-8">
4 <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
5 <title>{{ server_name }} Login</title>
4 <title>Choose identity provider</title>
5 <style type="text/css">
6 {% include "sso.css" without context %}
7
8 .providers {
9 list-style: none;
10 padding: 0;
11 }
12
13 .providers li {
14 margin: 12px;
15 }
16
17 .providers a {
18 display: block;
19 border-radius: 4px;
20 border: 1px solid #17191C;
21 padding: 8px;
22 text-align: center;
23 text-decoration: none;
24 color: #17191C;
25 display: flex;
26 align-items: center;
27 font-weight: bold;
28 }
29
30 .providers a img {
31 width: 24px;
32 height: 24px;
33 }
34 .providers a span {
35 flex: 1;
36 }
37 </style>
638 </head>
739 <body>
8 <div id="container">
9 <h1 id="title">{{ server_name }} Login</h1>
10 <div class="login_flow">
11 <p>Choose one of the following identity providers:</p>
12 <form>
13 <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
14 <ul class="radiobuttons">
15 {% for p in providers %}
16 <li>
17 <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}">
18 <label for="prov{{ loop.index }}">{{ p.idp_name }}</label>
19 {% if p.idp_icon %}
40 <header>
41 <h1>Log in to {{ server_name }} </h1>
42 <p>Choose an identity provider to log in</p>
43 </header>
44 <main>
45 <ul class="providers">
46 {% for p in providers %}
47 <li>
48 <a href="pick_idp?idp={{ p.idp_id }}&redirectUrl={{ redirect_url | urlencode }}">
49 {% if p.idp_icon %}
2050 <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/>
21 {% endif %}
22 </li>
23 {% endfor %}
24 </ul>
25 <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
26 </form>
27 </div>
28 </div>
51 {% endif %}
52 <span>{{ p.idp_name }}</span>
53 </a>
54 </li>
55 {% endfor %}
56 </ul>
57 </main>
58 {% include "sso_footer.html" without context %}
2959 </body>
3060 </html>
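The rewritten picker drops the radio-button form in favour of one plain link per identity provider, with the client's redirect URL percent-encoded into the link's query string. A minimal sketch of that link construction and round trip, assuming an illustrative provider id and redirect URL (Python's `quote` stands in here for Jinja's `urlencode` filter, which behaves similarly for this purpose):

```python
from urllib.parse import quote, parse_qs, urlparse

idp_id = "oidc-github"                          # illustrative provider id
redirect_url = "https://client.example.org/?room=!a:b"

# Mirrors `pick_idp?idp={{ p.idp_id }}&redirectUrl={{ redirect_url | urlencode }}`:
# encoding the redirect URL keeps its own query string from being
# confused with the picker link's parameters.
href = "pick_idp?idp=%s&redirectUrl=%s" % (idp_id, quote(redirect_url, safe=""))

# The server recovers the original redirect URL intact.
params = parse_qs(urlparse(href).query)
assert params["redirectUrl"] == [redirect_url]
assert params["idp"] == [idp_id]
```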
11 <html lang="en">
22 <head>
33 <meta charset="UTF-8">
4 <title>SSO redirect confirmation</title>
4 <title>Agree to terms and conditions</title>
55 <meta name="viewport" content="width=device-width, user-scalable=no">
66 <style type="text/css">
77 {% include "sso.css" without context %}
1717 <p>Agree to the terms to create your account.</p>
1818 </header>
1919 <main>
20 <!-- {% if user_profile.avatar_url and user_profile.display_name %} -->
21 <div class="profile">
22 <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
23 <div class="profile-details">
24 <div class="display-name">{{ user_profile.display_name }}</div>
25 <div class="user-id">{{ user_id }}</div>
26 </div>
27 </div>
28 <!-- {% endif %} -->
20 {% include "sso_partial_profile.html" %}
2921 <form method="post" action="{{my_url}}" id="consent_form">
3022 <p>
3123 <input id="accepted_version" type="checkbox" name="accepted_version" value="{{ consent_version }}" required>
32 <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank">terms and conditions</a>.</label>
24 <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank" rel="noopener">terms and conditions</a>.</label>
3325 </p>
3426 <input type="submit" class="primary-button" value="Continue"/>
3527 </form>
3628 </main>
29 {% include "sso_footer.html" without context %}
3730 </body>
3831 </html>
0 {# html fragment to be included in SSO pages, to show the user's profile #}
1
2 <div class="profile{% if user_profile.avatar_url %} with-avatar{% endif %}">
3 {% if user_profile.avatar_url %}
4 <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
5 {% endif %}
6 {# users that signed up with SSO will have a display_name of some sort;
7 however that is not the case for users who signed up via other
8 methods, so we need to handle that.
9 #}
10 {% if user_profile.display_name %}
11 <div class="display-name">{{ user_profile.display_name }}</div>
12 {% else %}
13 {# split the userid on ':', take the part before the first ':',
14 and then remove the leading '@'. #}
15 <div class="display-name">{{ user_id.split(":")[0][1:] }}</div>
16 {% endif %}
17 <div class="user-id">{{ user_id }}</div>
18 </div>
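The fallback branch above derives a display name from the Matrix ID itself when the profile has none. The same extraction as a standalone sketch (the example user id is illustrative):

```python
def sso_display_name(user_id: str) -> str:
    # Same expression as the template: take the part before the first ':',
    # then strip the leading '@' ("@alice:example.org" -> "alice").
    return user_id.split(":")[0][1:]

assert sso_display_name("@alice:example.org") == "alice"
```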
11 <html lang="en">
22 <head>
33 <meta charset="UTF-8">
4 <title>SSO redirect confirmation</title>
4 <title>Continue to your account</title>
55 <meta name="viewport" content="width=device-width, user-scalable=no">
66 <style type="text/css">
77 {% include "sso.css" without context %}
8
9 .confirm-trust {
10 margin: 34px 0;
11 color: #8D99A5;
12 }
13 .confirm-trust strong {
14 color: #17191C;
15 }
16
17 .confirm-trust::before {
18 content: "";
19 background-image: url('data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTgiIGhlaWdodD0iMTgiIHZpZXdCb3g9IjAgMCAxOCAxOCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZmlsbC1ydWxlPSJldmVub2RkIiBjbGlwLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xNi41IDlDMTYuNSAxMy4xNDIxIDEzLjE0MjEgMTYuNSA5IDE2LjVDNC44NTc4NiAxNi41IDEuNSAxMy4xNDIxIDEuNSA5QzEuNSA0Ljg1Nzg2IDQuODU3ODYgMS41IDkgMS41QzEzLjE0MjEgMS41IDE2LjUgNC44NTc4NiAxNi41IDlaTTcuMjUgOUM3LjI1IDkuNDY1OTYgNy41Njg2OSA5Ljg1NzQ4IDggOS45Njg1VjEyLjM3NUM4IDEyLjkyNzMgOC40NDc3MiAxMy4zNzUgOSAxMy4zNzVIMTAuMTI1QzEwLjY3NzMgMTMuMzc1IDExLjEyNSAxMi45MjczIDExLjEyNSAxMi4zNzVDMTEuMTI1IDExLjgyMjcgMTAuNjc3MyAxMS4zNzUgMTAuMTI1IDExLjM3NUgxMFY5QzEwIDguOTY1NDggOS45OTgyNSA4LjkzMTM3IDkuOTk0ODQgOC44OTc3NkM5Ljk0MzYzIDguMzkzNSA5LjUxNzc3IDggOSA4SDguMjVDNy42OTc3MiA4IDcuMjUgOC40NDc3MiA3LjI1IDlaTTkgNy41QzkuNjIxMzIgNy41IDEwLjEyNSA2Ljk5NjMyIDEwLjEyNSA2LjM3NUMxMC4xMjUgNS43NTM2OCA5LjYyMTMyIDUuMjUgOSA1LjI1QzguMzc4NjggNS4yNSA3Ljg3NSA1Ljc1MzY4IDcuODc1IDYuMzc1QzcuODc1IDYuOTk2MzIgOC4zNzg2OCA3LjUgOSA3LjVaIiBmaWxsPSIjQzFDNkNEIi8+Cjwvc3ZnPgoK');
20 background-repeat: no-repeat;
21 width: 24px;
22 height: 24px;
23 display: block;
24 float: left;
25 }
826 </style>
927 </head>
1028 <body>
1129 <header>
12 {% if new_user %}
13 <h1>Your account is now ready</h1>
14 <p>You've made your account on {{ server_name }}.</p>
15 {% else %}
16 <h1>Log in</h1>
17 {% endif %}
18 <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p>
30 <h1>Continue to your account</h1>
1931 </header>
2032 <main>
21 {% if user_profile.avatar_url %}
22 <div class="profile">
23 <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
24 <div class="profile-details">
25 {% if user_profile.display_name %}
26 <div class="display-name">{{ user_profile.display_name }}</div>
27 {% endif %}
28 <div class="user-id">{{ user_id }}</div>
29 </div>
30 </div>
31 {% endif %}
33 {% include "sso_partial_profile.html" %}
34 <p class="confirm-trust">Continuing will grant <strong>{{ display_url }}</strong> access to your account.</p>
3235 <a href="{{ redirect_url }}" class="primary-button">Continue</a>
3336 </main>
37 {% include "sso_footer.html" without context %}
3438 </body>
3539 </html>
4141 JoinRoomAliasServlet,
4242 ListRoomRestServlet,
4343 MakeRoomAdminRestServlet,
44 RoomEventContextServlet,
4445 RoomMembersRestServlet,
4546 RoomRestServlet,
4647 RoomStateRestServlet,
237238 MakeRoomAdminRestServlet(hs).register(http_server)
238239 ShadowBanRestServlet(hs).register(http_server)
239240 ForwardExtremitiesRestServlet(hs).register(http_server)
241 RoomEventContextServlet(hs).register(http_server)
240242
241243
242244 def register_servlets_for_client_rest_resource(hs, http_server):
2121
2222
2323 class DeleteGroupAdminRestServlet(RestServlet):
24 """Allows deleting of local groups
25 """
24 """Allows deleting of local groups"""
2625
2726 PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
2827
118118
119119
120120 class ProtectMediaByID(RestServlet):
121 """Protect local media from being quarantined.
122 """
121 """Protect local media from being quarantined."""
123122
124123 PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
125124
140139
141140
142141 class ListMediaInRoom(RestServlet):
143 """Lists all of the media in a given room.
144 """
142 """Lists all of the media in a given room."""
145143
146144 PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
147145
179177
180178
181179 class DeleteMediaByID(RestServlet):
182 """Delete local media by a given ID. Removes it from this server.
183 """
180 """Delete local media by a given ID. Removes it from this server."""
184181
185182 PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
186183
1414 import logging
1515 from http import HTTPStatus
1616 from typing import TYPE_CHECKING, List, Optional, Tuple
17 from urllib import parse as urlparse
1718
1819 from synapse.api.constants import EventTypes, JoinRules, Membership
1920 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
21 from synapse.api.filtering import Filter
2022 from synapse.http.servlet import (
2123 RestServlet,
2224 assert_params_in_dict,
3234 )
3335 from synapse.storage.databases.main.room import RoomSortOrder
3436 from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
37 from synapse.util import json_decoder
3538
3639 if TYPE_CHECKING:
3740 from synapse.server import HomeServer
478481
479482 if not admin_user_id:
480483 raise SynapseError(
481 400, "No local admin user in room",
484 400,
485 "No local admin user in room",
482486 )
483487
484488 pl_content = power_levels.content
488492 admin_user_id = create_event.sender
489493 if not self.is_mine_id(admin_user_id):
490494 raise SynapseError(
491 400, "No local admin user in room",
495 400,
496 "No local admin user in room",
492497 )
493498
494499 # Grant the user power equal to the room admin by attempting to send an
498503 new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
499504
500505 fake_requester = create_requester(
501 admin_user_id, authenticated_entity=requester.authenticated_entity,
506 admin_user_id,
507 authenticated_entity=requester.authenticated_entity,
502508 )
503509
504510 try:
604610
605611 extremities = await self.store.get_forward_extremities_for_room(room_id)
606612 return 200, {"count": len(extremities), "results": extremities}
613
614
615 class RoomEventContextServlet(RestServlet):
616 """
617 Provide the context for an event.
618 This API is designed to be used when system administrators wish to look at
619 an abuse report and understand what happened during and immediately prior
620 to this event.
621 """
622
623 PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
624
625 def __init__(self, hs):
626 super().__init__()
627 self.clock = hs.get_clock()
628 self.room_context_handler = hs.get_room_context_handler()
629 self._event_serializer = hs.get_event_client_serializer()
630 self.auth = hs.get_auth()
631
632 async def on_GET(self, request, room_id, event_id):
633 requester = await self.auth.get_user_by_req(request, allow_guest=False)
634 await assert_user_is_admin(self.auth, requester.user)
635
636 limit = parse_integer(request, "limit", default=10)
637
638 # picking the API shape for symmetry with /messages
639 filter_str = parse_string(request, b"filter", encoding="utf-8")
640 if filter_str:
641 filter_json = urlparse.unquote(filter_str)
642 event_filter = Filter(
643 json_decoder.decode(filter_json)
644 ) # type: Optional[Filter]
645 else:
646 event_filter = None
647
648 results = await self.room_context_handler.get_event_context(
649 requester,
650 room_id,
651 event_id,
652 limit,
653 event_filter,
654 use_admin_priviledge=True,
655 )
656
657 if not results:
658 raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
659
660 time_now = self.clock.time_msec()
661 results["events_before"] = await self._event_serializer.serialize_events(
662 results["events_before"], time_now
663 )
664 results["event"] = await self._event_serializer.serialize_event(
665 results["event"], time_now
666 )
667 results["events_after"] = await self._event_serializer.serialize_events(
668 results["events_after"], time_now
669 )
670 results["state"] = await self._event_serializer.serialize_events(
671 results["state"], time_now
672 )
673
674 return 200, results
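The new servlet exposes event context at `/_synapse/admin/v1/rooms/<room_id>/context/<event_id>`, taking the same URL-encoded JSON `filter` parameter as `/messages`. A sketch of building such a request path, and of the unquote-then-decode round trip the servlet performs on the filter (the room and event ids are illustrative):

```python
import json
from urllib.parse import quote, unquote

room_id = "!abc123:example.org"   # illustrative room id
event_id = "$event456"            # illustrative event id
event_filter = {"types": ["m.room.message"]}  # only show message events

# The filter travels as URL-encoded JSON in the query string.
filter_param = quote(json.dumps(event_filter), safe="")
path = "/_synapse/admin/v1/rooms/%s/context/%s?limit=10&filter=%s" % (
    quote(room_id, safe=""),
    quote(event_id, safe=""),
    filter_param,
)

# Server side (see on_GET above): unquote, then JSON-decode into a Filter.
decoded = json.loads(unquote(filter_param))
assert decoded == event_filter
```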
578578 }
579579 Returns:
580580 200 OK with empty object if success otherwise an error.
581 """
581 """
582582
583583 PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
584584
751751
752752 Returns:
753753 pushers: Dictionary containing pushers information.
754 total: Number of pushers in dictonary `pushers`.
754 total: Number of pushers in dictionary `pushers`.
755755 """
756756
757757 PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
309309 except jwt.PyJWTError as e:
310310 # A JWT error occurred, return some info back to the client.
311311 raise LoginError(
312 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
312 403,
313 "JWT validation failed: %s" % (str(e),),
314 errcode=Codes.FORBIDDEN,
313315 )
314316
315317 user = payload.get("sub", None)
374376 request, "redirectUrl", required=True, encoding=None
375377 )
376378 sso_url = await self._sso_handler.handle_redirect_request(
377 request, client_redirect_url, idp_id,
379 request,
380 client_redirect_url,
381 idp_id,
378382 )
379383 logger.info("Redirecting to %s", sso_url)
380384 request.redirect(sso_url)
5959 new_name = content["displayname"]
6060 except Exception:
6161 raise SynapseError(
62 code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
62 code=400,
63 msg="Unable to parse name",
64 errcode=Codes.BAD_JSON,
6365 )
6466
6567 await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
158158 self.notifier.on_new_replication_data()
159159
160160 respond_with_html_bytes(
161 request, 200, PushersRemoveRestServlet.SUCCESS_HTML,
161 request,
162 200,
163 PushersRemoveRestServlet.SUCCESS_HTML,
162164 )
163165 return None
164166
361361 parse_and_validate_server_name(server)
362362 except ValueError:
363363 raise SynapseError(
364 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
364 400,
365 "Invalid server name: %s" % (server,),
366 Codes.INVALID_PARAM,
365367 )
366368
367369 try:
412414 parse_and_validate_server_name(server)
413415 except ValueError:
414416 raise SynapseError(
415 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
417 400,
418 "Invalid server name: %s" % (server,),
419 Codes.INVALID_PARAM,
416420 )
417421
418422 try:
649653 event_filter = None
650654
651655 results = await self.room_context_handler.get_event_context(
652 requester.user, room_id, event_id, limit, event_filter
656 requester, room_id, event_id, limit, event_filter
653657 )
654658
655659 if not results:
192192 requester = await self.auth.get_user_by_req(request)
193193 try:
194194 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
195 requester, request, body, "modify your account password",
195 requester,
196 request,
197 body,
198 "modify your account password",
196199 )
197200 except InteractiveAuthIncompleteError as e:
198201 # The user needs to provide more steps to complete auth, but
311314 return 200, {}
312315
313316 await self.auth_handler.validate_user_via_ui_auth(
314 requester, request, body, "deactivate your account",
317 requester,
318 request,
319 body,
320 "deactivate your account",
315321 )
316322 result = await self._deactivate_account_handler.deactivate_account(
317323 requester.user.to_string(),
702708 assert_valid_client_secret(client_secret)
703709
704710 await self.auth_handler.validate_user_via_ui_auth(
705 requester, request, body, "add a third-party identifier to your account",
711 requester,
712 request,
713 body,
714 "add a third-party identifier to your account",
706715 )
707716
708717 validation_session = await self.identity_handler.validate_threepid_session(
8282 assert_params_in_dict(body, ["devices"])
8383
8484 await self.auth_handler.validate_user_via_ui_auth(
85 requester, request, body, "remove device(s) from your account",
85 requester,
86 request,
87 body,
88 "remove device(s) from your account",
8689 )
8790
8891 await self.device_handler.delete_devices(
128131 raise
129132
130133 await self.auth_handler.validate_user_via_ui_auth(
131 requester, request, body, "remove a device from your account",
134 requester,
135 request,
136 body,
137 "remove a device from your account",
132138 )
133139
134140 await self.device_handler.delete_device(requester.user.to_string(), device_id)
205211
206212 if "device_data" not in submission:
207213 raise errors.SynapseError(
208 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
214 400,
215 "device_data missing",
216 errcode=errors.Codes.MISSING_PARAM,
209217 )
210218 elif not isinstance(submission["device_data"], dict):
211219 raise errors.SynapseError(
258266
259267 if "device_id" not in submission:
260268 raise errors.SynapseError(
261 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
269 400,
270 "device_id missing",
271 errcode=errors.Codes.MISSING_PARAM,
262272 )
263273 elif not isinstance(submission["device_id"], str):
264274 raise errors.SynapseError(
265 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
275 400,
276 "device_id must be a string",
277 errcode=errors.Codes.INVALID_PARAM,
266278 )
267279
268280 result = await self.device_handler.rehydrate_device(
1515
1616 import logging
1717 from functools import wraps
18
19 from synapse.api.errors import SynapseError
20 from synapse.http.servlet import RestServlet, parse_json_object_from_request
21 from synapse.types import GroupID
18 from typing import TYPE_CHECKING, Optional, Tuple
19
20 from twisted.web.http import Request
21
22 from synapse.api.constants import (
23 MAX_GROUP_CATEGORYID_LENGTH,
24 MAX_GROUP_ROLEID_LENGTH,
25 MAX_GROUPID_LENGTH,
26 )
27 from synapse.api.errors import Codes, SynapseError
28 from synapse.handlers.groups_local import GroupsLocalHandler
29 from synapse.http.servlet import (
30 RestServlet,
31 assert_params_in_dict,
32 parse_json_object_from_request,
33 )
34 from synapse.types import GroupID, JsonDict
2235
2336 from ._base import client_patterns
37
38 if TYPE_CHECKING:
39 from synapse.app.homeserver import HomeServer
2440
2541 logger = logging.getLogger(__name__)
2642
3248 """
3349
3450 @wraps(f)
35 def wrapper(self, request, group_id, *args, **kwargs):
51 def wrapper(self, request: Request, group_id: str, *args, **kwargs):
3652 if not GroupID.is_valid(group_id):
3753 raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
3854
4258
4359
4460 class GroupServlet(RestServlet):
45 """Get the group profile
46 """
61 """Get the group profile"""
4762
4863 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
4964
50 def __init__(self, hs):
51 super().__init__()
52 self.auth = hs.get_auth()
53 self.clock = hs.get_clock()
54 self.groups_handler = hs.get_groups_local_handler()
55
56 @_validate_group_id
57 async def on_GET(self, request, group_id):
65 def __init__(self, hs: "HomeServer"):
66 super().__init__()
67 self.auth = hs.get_auth()
68 self.clock = hs.get_clock()
69 self.groups_handler = hs.get_groups_local_handler()
70
71 @_validate_group_id
72 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
5873 requester = await self.auth.get_user_by_req(request, allow_guest=True)
5974 requester_user_id = requester.user.to_string()
6075
6580 return 200, group_description
6681
6782 @_validate_group_id
68 async def on_POST(self, request, group_id):
69 requester = await self.auth.get_user_by_req(request)
70 requester_user_id = requester.user.to_string()
71
72 content = parse_json_object_from_request(request)
83 async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
84 requester = await self.auth.get_user_by_req(request)
85 requester_user_id = requester.user.to_string()
86
87 content = parse_json_object_from_request(request)
88 assert_params_in_dict(
89 content, ("name", "avatar_url", "short_description", "long_description")
90 )
91 assert isinstance(
92 self.groups_handler, GroupsLocalHandler
93 ), "Workers cannot create group profiles."
7394 await self.groups_handler.update_group_profile(
7495 group_id, requester_user_id, content
7596 )
7899
79100
80101 class GroupSummaryServlet(RestServlet):
81 """Get the full group summary
82 """
102 """Get the full group summary"""
83103
84104 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
85105
86 def __init__(self, hs):
87 super().__init__()
88 self.auth = hs.get_auth()
89 self.clock = hs.get_clock()
90 self.groups_handler = hs.get_groups_local_handler()
91
92 @_validate_group_id
93 async def on_GET(self, request, group_id):
106 def __init__(self, hs: "HomeServer"):
107 super().__init__()
108 self.auth = hs.get_auth()
109 self.clock = hs.get_clock()
110 self.groups_handler = hs.get_groups_local_handler()
111
112 @_validate_group_id
113 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
94114 requester = await self.auth.get_user_by_req(request, allow_guest=True)
95115 requester_user_id = requester.user.to_string()
96116
115135 "/rooms/(?P<room_id>[^/]*)$"
116136 )
117137
118 def __init__(self, hs):
119 super().__init__()
120 self.auth = hs.get_auth()
121 self.clock = hs.get_clock()
122 self.groups_handler = hs.get_groups_local_handler()
123
124 @_validate_group_id
125 async def on_PUT(self, request, group_id, category_id, room_id):
126 requester = await self.auth.get_user_by_req(request)
127 requester_user_id = requester.user.to_string()
128
129 content = parse_json_object_from_request(request)
138 def __init__(self, hs: "HomeServer"):
139 super().__init__()
140 self.auth = hs.get_auth()
141 self.clock = hs.get_clock()
142 self.groups_handler = hs.get_groups_local_handler()
143
144 @_validate_group_id
145 async def on_PUT(
146 self, request: Request, group_id: str, category_id: Optional[str], room_id: str
147 ):
148 requester = await self.auth.get_user_by_req(request)
149 requester_user_id = requester.user.to_string()
150
151 if category_id == "":
152 raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
153
154 if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
155 raise SynapseError(
156 400,
157 "category_id may not be longer than %s characters"
158 % (MAX_GROUP_CATEGORYID_LENGTH,),
159 Codes.INVALID_PARAM,
160 )
161
162 content = parse_json_object_from_request(request)
163 assert isinstance(
164 self.groups_handler, GroupsLocalHandler
165 ), "Workers cannot modify group summaries."
130166 resp = await self.groups_handler.update_group_summary_room(
131167 group_id,
132168 requester_user_id,
138174 return 200, resp
139175
140176 @_validate_group_id
141 async def on_DELETE(self, request, group_id, category_id, room_id):
142 requester = await self.auth.get_user_by_req(request)
143 requester_user_id = requester.user.to_string()
144
177 async def on_DELETE(
178 self, request: Request, group_id: str, category_id: str, room_id: str
179 ):
180 requester = await self.auth.get_user_by_req(request)
181 requester_user_id = requester.user.to_string()
182
183 assert isinstance(
184 self.groups_handler, GroupsLocalHandler
185 ), "Workers cannot modify group profiles."
145186 resp = await self.groups_handler.delete_group_summary_room(
146187 group_id, requester_user_id, room_id=room_id, category_id=category_id
147188 )
150191
151192
152193 class GroupCategoryServlet(RestServlet):
153 """Get/add/update/delete a group category
154 """
194 """Get/add/update/delete a group category"""
155195
156196 PATTERNS = client_patterns(
157197 "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
158198 )
159199
160 def __init__(self, hs):
161 super().__init__()
162 self.auth = hs.get_auth()
163 self.clock = hs.get_clock()
164 self.groups_handler = hs.get_groups_local_handler()
165
166 @_validate_group_id
167 async def on_GET(self, request, group_id, category_id):
200 def __init__(self, hs: "HomeServer"):
201 super().__init__()
202 self.auth = hs.get_auth()
203 self.clock = hs.get_clock()
204 self.groups_handler = hs.get_groups_local_handler()
205
206 @_validate_group_id
207 async def on_GET(
208 self, request: Request, group_id: str, category_id: str
209 ) -> Tuple[int, JsonDict]:
168210 requester = await self.auth.get_user_by_req(request, allow_guest=True)
169211 requester_user_id = requester.user.to_string()
170212
175217 return 200, category
176218
177219 @_validate_group_id
178 async def on_PUT(self, request, group_id, category_id):
179 requester = await self.auth.get_user_by_req(request)
180 requester_user_id = requester.user.to_string()
181
182 content = parse_json_object_from_request(request)
220 async def on_PUT(
221 self, request: Request, group_id: str, category_id: str
222 ) -> Tuple[int, JsonDict]:
223 requester = await self.auth.get_user_by_req(request)
224 requester_user_id = requester.user.to_string()
225
226 if not category_id:
227 raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
228
229 if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
230 raise SynapseError(
231 400,
232 "category_id may not be longer than %s characters"
233 % (MAX_GROUP_CATEGORYID_LENGTH,),
234 Codes.INVALID_PARAM,
235 )
236
237 content = parse_json_object_from_request(request)
238 assert isinstance(
239 self.groups_handler, GroupsLocalHandler
240 ), "Workers cannot modify group categories."
183241 resp = await self.groups_handler.update_group_category(
184242 group_id, requester_user_id, category_id=category_id, content=content
185243 )
187245 return 200, resp
188246
189247 @_validate_group_id
190 async def on_DELETE(self, request, group_id, category_id):
191 requester = await self.auth.get_user_by_req(request)
192 requester_user_id = requester.user.to_string()
193
248 async def on_DELETE(
249 self, request: Request, group_id: str, category_id: str
250 ) -> Tuple[int, JsonDict]:
251 requester = await self.auth.get_user_by_req(request)
252 requester_user_id = requester.user.to_string()
253
254 assert isinstance(
255 self.groups_handler, GroupsLocalHandler
256 ), "Workers cannot modify group categories."
194257 resp = await self.groups_handler.delete_group_category(
195258 group_id, requester_user_id, category_id=category_id
196259 )
199262
200263
201264 class GroupCategoriesServlet(RestServlet):
202 """Get all group categories
203 """
265 """Get all group categories"""
204266
205267 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
206268
207 def __init__(self, hs):
208 super().__init__()
209 self.auth = hs.get_auth()
210 self.clock = hs.get_clock()
211 self.groups_handler = hs.get_groups_local_handler()
212
213 @_validate_group_id
214 async def on_GET(self, request, group_id):
269 def __init__(self, hs: "HomeServer"):
270 super().__init__()
271 self.auth = hs.get_auth()
272 self.clock = hs.get_clock()
273 self.groups_handler = hs.get_groups_local_handler()
274
275 @_validate_group_id
276 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
215277 requester = await self.auth.get_user_by_req(request, allow_guest=True)
216278 requester_user_id = requester.user.to_string()
217279
223285
224286
225287 class GroupRoleServlet(RestServlet):
226 """Get/add/update/delete a group role
227 """
288 """Get/add/update/delete a group role"""
228289
229290 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
230291
231 def __init__(self, hs):
232 super().__init__()
233 self.auth = hs.get_auth()
234 self.clock = hs.get_clock()
235 self.groups_handler = hs.get_groups_local_handler()
236
237 @_validate_group_id
238 async def on_GET(self, request, group_id, role_id):
292 def __init__(self, hs: "HomeServer"):
293 super().__init__()
294 self.auth = hs.get_auth()
295 self.clock = hs.get_clock()
296 self.groups_handler = hs.get_groups_local_handler()
297
298 @_validate_group_id
299 async def on_GET(
300 self, request: Request, group_id: str, role_id: str
301 ) -> Tuple[int, JsonDict]:
239302 requester = await self.auth.get_user_by_req(request, allow_guest=True)
240303 requester_user_id = requester.user.to_string()
241304
246309 return 200, category
247310
248311 @_validate_group_id
249 async def on_PUT(self, request, group_id, role_id):
250 requester = await self.auth.get_user_by_req(request)
251 requester_user_id = requester.user.to_string()
252
253 content = parse_json_object_from_request(request)
312 async def on_PUT(
313 self, request: Request, group_id: str, role_id: str
314 ) -> Tuple[int, JsonDict]:
315 requester = await self.auth.get_user_by_req(request)
316 requester_user_id = requester.user.to_string()
317
318 if not role_id:
319 raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
320
321 if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
322 raise SynapseError(
323 400,
324 "role_id may not be longer than %s characters"
325 % (MAX_GROUP_ROLEID_LENGTH,),
326 Codes.INVALID_PARAM,
327 )
328
329 content = parse_json_object_from_request(request)
330 assert isinstance(
331 self.groups_handler, GroupsLocalHandler
332 ), "Workers cannot modify group roles."
254333 resp = await self.groups_handler.update_group_role(
255334 group_id, requester_user_id, role_id=role_id, content=content
256335 )
258337 return 200, resp
259338
260339 @_validate_group_id
261 async def on_DELETE(self, request, group_id, role_id):
262 requester = await self.auth.get_user_by_req(request)
263 requester_user_id = requester.user.to_string()
264
340 async def on_DELETE(
341 self, request: Request, group_id: str, role_id: str
342 ) -> Tuple[int, JsonDict]:
343 requester = await self.auth.get_user_by_req(request)
344 requester_user_id = requester.user.to_string()
345
346 assert isinstance(
347 self.groups_handler, GroupsLocalHandler
348 ), "Workers cannot modify group roles."
265349 resp = await self.groups_handler.delete_group_role(
266350 group_id, requester_user_id, role_id=role_id
267351 )
270354
271355
272356 class GroupRolesServlet(RestServlet):
273 """Get all group roles
274 """
357 """Get all group roles"""
275358
276359 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
277360
278 def __init__(self, hs):
279 super().__init__()
280 self.auth = hs.get_auth()
281 self.clock = hs.get_clock()
282 self.groups_handler = hs.get_groups_local_handler()
283
284 @_validate_group_id
285 async def on_GET(self, request, group_id):
361 def __init__(self, hs: "HomeServer"):
362 super().__init__()
363 self.auth = hs.get_auth()
364 self.clock = hs.get_clock()
365 self.groups_handler = hs.get_groups_local_handler()
366
367 @_validate_group_id
368 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
286369 requester = await self.auth.get_user_by_req(request, allow_guest=True)
287370 requester_user_id = requester.user.to_string()
288371
307390 "/users/(?P<user_id>[^/]*)$"
308391 )
309392
310 def __init__(self, hs):
311 super().__init__()
312 self.auth = hs.get_auth()
313 self.clock = hs.get_clock()
314 self.groups_handler = hs.get_groups_local_handler()
315
316 @_validate_group_id
317 async def on_PUT(self, request, group_id, role_id, user_id):
318 requester = await self.auth.get_user_by_req(request)
319 requester_user_id = requester.user.to_string()
320
321 content = parse_json_object_from_request(request)
393 def __init__(self, hs: "HomeServer"):
394 super().__init__()
395 self.auth = hs.get_auth()
396 self.clock = hs.get_clock()
397 self.groups_handler = hs.get_groups_local_handler()
398
399 @_validate_group_id
400 async def on_PUT(
401 self, request: Request, group_id: str, role_id: Optional[str], user_id: str
402 ) -> Tuple[int, JsonDict]:
403 requester = await self.auth.get_user_by_req(request)
404 requester_user_id = requester.user.to_string()
405
406 if role_id == "":
407 raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
408
409 if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
410 raise SynapseError(
411 400,
412 "role_id may not be longer than %s characters"
413 % (MAX_GROUP_ROLEID_LENGTH,),
414 Codes.INVALID_PARAM,
415 )
416
417 content = parse_json_object_from_request(request)
418 assert isinstance(
419 self.groups_handler, GroupsLocalHandler
420 ), "Workers cannot modify group summaries."
322421 resp = await self.groups_handler.update_group_summary_user(
323422 group_id,
324423 requester_user_id,
330429 return 200, resp
331430
332431 @_validate_group_id
333 async def on_DELETE(self, request, group_id, role_id, user_id):
334 requester = await self.auth.get_user_by_req(request)
335 requester_user_id = requester.user.to_string()
336
432 async def on_DELETE(
433 self, request: Request, group_id: str, role_id: str, user_id: str
434 ):
435 requester = await self.auth.get_user_by_req(request)
436 requester_user_id = requester.user.to_string()
437
438 assert isinstance(
439 self.groups_handler, GroupsLocalHandler
440 ), "Workers cannot modify group summaries."
337441 resp = await self.groups_handler.delete_group_summary_user(
338442 group_id, requester_user_id, user_id=user_id, role_id=role_id
339443 )
342446
343447
344448 class GroupRoomServlet(RestServlet):
345 """Get all rooms in a group
346 """
449 """Get all rooms in a group"""
347450
348451 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
349452
350 def __init__(self, hs):
351 super().__init__()
352 self.auth = hs.get_auth()
353 self.clock = hs.get_clock()
354 self.groups_handler = hs.get_groups_local_handler()
355
356 @_validate_group_id
357 async def on_GET(self, request, group_id):
453 def __init__(self, hs: "HomeServer"):
454 super().__init__()
455 self.auth = hs.get_auth()
456 self.clock = hs.get_clock()
457 self.groups_handler = hs.get_groups_local_handler()
458
459 @_validate_group_id
460 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
358461 requester = await self.auth.get_user_by_req(request, allow_guest=True)
359462 requester_user_id = requester.user.to_string()
360463
366469
367470
368471 class GroupUsersServlet(RestServlet):
369 """Get all users in a group
370 """
472 """Get all users in a group"""
371473
372474 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
373475
374 def __init__(self, hs):
375 super().__init__()
376 self.auth = hs.get_auth()
377 self.clock = hs.get_clock()
378 self.groups_handler = hs.get_groups_local_handler()
379
380 @_validate_group_id
381 async def on_GET(self, request, group_id):
476 def __init__(self, hs: "HomeServer"):
477 super().__init__()
478 self.auth = hs.get_auth()
479 self.clock = hs.get_clock()
480 self.groups_handler = hs.get_groups_local_handler()
481
482 @_validate_group_id
483 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
382484 requester = await self.auth.get_user_by_req(request, allow_guest=True)
383485 requester_user_id = requester.user.to_string()
384486
390492
391493
392494 class GroupInvitedUsersServlet(RestServlet):
393 """Get users invited to a group
394 """
495 """Get users invited to a group"""
395496
396497 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
397498
398 def __init__(self, hs):
399 super().__init__()
400 self.auth = hs.get_auth()
401 self.clock = hs.get_clock()
402 self.groups_handler = hs.get_groups_local_handler()
403
404 @_validate_group_id
405 async def on_GET(self, request, group_id):
499 def __init__(self, hs: "HomeServer"):
500 super().__init__()
501 self.auth = hs.get_auth()
502 self.clock = hs.get_clock()
503 self.groups_handler = hs.get_groups_local_handler()
504
505 @_validate_group_id
506 async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
406507 requester = await self.auth.get_user_by_req(request)
407508 requester_user_id = requester.user.to_string()
408509
414515
415516
416517 class GroupSettingJoinPolicyServlet(RestServlet):
417 """Set group join policy
418 """
518 """Set group join policy"""
419519
420520 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
421521
422 def __init__(self, hs):
423 super().__init__()
424 self.auth = hs.get_auth()
425 self.groups_handler = hs.get_groups_local_handler()
426
427 @_validate_group_id
428 async def on_PUT(self, request, group_id):
429 requester = await self.auth.get_user_by_req(request)
430 requester_user_id = requester.user.to_string()
431
432 content = parse_json_object_from_request(request)
433
522 def __init__(self, hs: "HomeServer"):
523 super().__init__()
524 self.auth = hs.get_auth()
525 self.groups_handler = hs.get_groups_local_handler()
526
527 @_validate_group_id
528 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
529 requester = await self.auth.get_user_by_req(request)
530 requester_user_id = requester.user.to_string()
531
532 content = parse_json_object_from_request(request)
533
534 assert isinstance(
535 self.groups_handler, GroupsLocalHandler
536 ), "Workers cannot modify group join policy."
434537 result = await self.groups_handler.set_group_join_policy(
435538 group_id, requester_user_id, content
436539 )
439542
440543
441544 class GroupCreateServlet(RestServlet):
442 """Create a group
443 """
545 """Create a group"""
444546
445547 PATTERNS = client_patterns("/create_group$")
446548
447 def __init__(self, hs):
549 def __init__(self, hs: "HomeServer"):
448550 super().__init__()
449551 self.auth = hs.get_auth()
450552 self.clock = hs.get_clock()
451553 self.groups_handler = hs.get_groups_local_handler()
452554 self.server_name = hs.hostname
453555
454 async def on_POST(self, request):
556 async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
455557 requester = await self.auth.get_user_by_req(request)
456558 requester_user_id = requester.user.to_string()
457559
460562 localpart = content.pop("localpart")
461563 group_id = GroupID(localpart, self.server_name).to_string()
462564
565 if not localpart:
566 raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
567
568 if len(group_id) > MAX_GROUPID_LENGTH:
569 raise SynapseError(
570 400,
571 "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
572 Codes.INVALID_PARAM,
573 )
574
575 assert isinstance(
576 self.groups_handler, GroupsLocalHandler
577 ), "Workers cannot create groups."
463578 result = await self.groups_handler.create_group(
464579 group_id, requester_user_id, content
465580 )
468583
469584
470585 class GroupAdminRoomsServlet(RestServlet):
471 """Add a room to the group
472 """
586 """Add a room to the group"""
473587
474588 PATTERNS = client_patterns(
475589 "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
476590 )
477591
478 def __init__(self, hs):
479 super().__init__()
480 self.auth = hs.get_auth()
481 self.clock = hs.get_clock()
482 self.groups_handler = hs.get_groups_local_handler()
483
484 @_validate_group_id
485 async def on_PUT(self, request, group_id, room_id):
486 requester = await self.auth.get_user_by_req(request)
487 requester_user_id = requester.user.to_string()
488
489 content = parse_json_object_from_request(request)
592 def __init__(self, hs: "HomeServer"):
593 super().__init__()
594 self.auth = hs.get_auth()
595 self.clock = hs.get_clock()
596 self.groups_handler = hs.get_groups_local_handler()
597
598 @_validate_group_id
599 async def on_PUT(
600 self, request: Request, group_id: str, room_id: str
601 ) -> Tuple[int, JsonDict]:
602 requester = await self.auth.get_user_by_req(request)
603 requester_user_id = requester.user.to_string()
604
605 content = parse_json_object_from_request(request)
606 assert isinstance(
607 self.groups_handler, GroupsLocalHandler
608 ), "Workers cannot modify rooms in a group."
490609 result = await self.groups_handler.add_room_to_group(
491610 group_id, requester_user_id, room_id, content
492611 )
494613 return 200, result
495614
496615 @_validate_group_id
497 async def on_DELETE(self, request, group_id, room_id):
498 requester = await self.auth.get_user_by_req(request)
499 requester_user_id = requester.user.to_string()
500
616 async def on_DELETE(
617 self, request: Request, group_id: str, room_id: str
618 ) -> Tuple[int, JsonDict]:
619 requester = await self.auth.get_user_by_req(request)
620 requester_user_id = requester.user.to_string()
621
622 assert isinstance(
623 self.groups_handler, GroupsLocalHandler
624 ), "Workers cannot modify rooms in a group."
501625 result = await self.groups_handler.remove_room_from_group(
502626 group_id, requester_user_id, room_id
503627 )
506630
507631
508632 class GroupAdminRoomsConfigServlet(RestServlet):
509 """Update the config of a room in a group
510 """
633 """Update the config of a room in a group"""
511634
512635 PATTERNS = client_patterns(
513636 "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
514637 "/config/(?P<config_key>[^/]*)$"
515638 )
516639
517 def __init__(self, hs):
518 super().__init__()
519 self.auth = hs.get_auth()
520 self.clock = hs.get_clock()
521 self.groups_handler = hs.get_groups_local_handler()
522
523 @_validate_group_id
524 async def on_PUT(self, request, group_id, room_id, config_key):
525 requester = await self.auth.get_user_by_req(request)
526 requester_user_id = requester.user.to_string()
527
528 content = parse_json_object_from_request(request)
640 def __init__(self, hs: "HomeServer"):
641 super().__init__()
642 self.auth = hs.get_auth()
643 self.clock = hs.get_clock()
644 self.groups_handler = hs.get_groups_local_handler()
645
646 @_validate_group_id
647 async def on_PUT(
648 self, request: Request, group_id: str, room_id: str, config_key: str
649 ) -> Tuple[int, JsonDict]:
650 requester = await self.auth.get_user_by_req(request)
651 requester_user_id = requester.user.to_string()
652
653 content = parse_json_object_from_request(request)
654 assert isinstance(
655 self.groups_handler, GroupsLocalHandler
656 ), "Workers cannot modify rooms in a group."
529657 result = await self.groups_handler.update_room_in_group(
530658 group_id, requester_user_id, room_id, config_key, content
531659 )
534662
535663
536664 class GroupAdminUsersInviteServlet(RestServlet):
537 """Invite a user to the group
538 """
665 """Invite a user to the group"""
539666
540667 PATTERNS = client_patterns(
541668 "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
542669 )
543670
544 def __init__(self, hs):
671 def __init__(self, hs: "HomeServer"):
545672 super().__init__()
546673 self.auth = hs.get_auth()
547674 self.clock = hs.get_clock()
550677 self.is_mine_id = hs.is_mine_id
551678
552679 @_validate_group_id
553 async def on_PUT(self, request, group_id, user_id):
680 async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
554681 requester = await self.auth.get_user_by_req(request)
555682 requester_user_id = requester.user.to_string()
556683
557684 content = parse_json_object_from_request(request)
558685 config = content.get("config", {})
686 assert isinstance(
687 self.groups_handler, GroupsLocalHandler
688 ), "Workers cannot invite users to a group."
559689 result = await self.groups_handler.invite(
560690 group_id, user_id, requester_user_id, config
561691 )
564694
565695
566696 class GroupAdminUsersKickServlet(RestServlet):
567 """Kick a user from the group
568 """
697 """Kick a user from the group"""
569698
570699 PATTERNS = client_patterns(
571700 "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
572701 )
573702
574 def __init__(self, hs):
575 super().__init__()
576 self.auth = hs.get_auth()
577 self.clock = hs.get_clock()
578 self.groups_handler = hs.get_groups_local_handler()
579
580 @_validate_group_id
581 async def on_PUT(self, request, group_id, user_id):
582 requester = await self.auth.get_user_by_req(request)
583 requester_user_id = requester.user.to_string()
584
585 content = parse_json_object_from_request(request)
703 def __init__(self, hs: "HomeServer"):
704 super().__init__()
705 self.auth = hs.get_auth()
706 self.clock = hs.get_clock()
707 self.groups_handler = hs.get_groups_local_handler()
708
709 @_validate_group_id
710 async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
711 requester = await self.auth.get_user_by_req(request)
712 requester_user_id = requester.user.to_string()
713
714 content = parse_json_object_from_request(request)
715 assert isinstance(
716 self.groups_handler, GroupsLocalHandler
717 ), "Workers cannot kick users from a group."
586718 result = await self.groups_handler.remove_user_from_group(
587719 group_id, user_id, requester_user_id, content
588720 )
591723
592724
593725 class GroupSelfLeaveServlet(RestServlet):
594 """Leave a joined group
595 """
726 """Leave a joined group"""
596727
597728 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
598729
599 def __init__(self, hs):
600 super().__init__()
601 self.auth = hs.get_auth()
602 self.clock = hs.get_clock()
603 self.groups_handler = hs.get_groups_local_handler()
604
605 @_validate_group_id
606 async def on_PUT(self, request, group_id):
607 requester = await self.auth.get_user_by_req(request)
608 requester_user_id = requester.user.to_string()
609
610 content = parse_json_object_from_request(request)
730 def __init__(self, hs: "HomeServer"):
731 super().__init__()
732 self.auth = hs.get_auth()
733 self.clock = hs.get_clock()
734 self.groups_handler = hs.get_groups_local_handler()
735
736 @_validate_group_id
737 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
738 requester = await self.auth.get_user_by_req(request)
739 requester_user_id = requester.user.to_string()
740
741 content = parse_json_object_from_request(request)
742 assert isinstance(
743 self.groups_handler, GroupsLocalHandler
744 ), "Workers cannot leave a group for a user."
611745 result = await self.groups_handler.remove_user_from_group(
612746 group_id, requester_user_id, requester_user_id, content
613747 )
616750
617751
618752 class GroupSelfJoinServlet(RestServlet):
619 """Attempt to join a group, or knock
620 """
753 """Attempt to join a group, or knock"""
621754
622755 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
623756
624 def __init__(self, hs):
625 super().__init__()
626 self.auth = hs.get_auth()
627 self.clock = hs.get_clock()
628 self.groups_handler = hs.get_groups_local_handler()
629
630 @_validate_group_id
631 async def on_PUT(self, request, group_id):
632 requester = await self.auth.get_user_by_req(request)
633 requester_user_id = requester.user.to_string()
634
635 content = parse_json_object_from_request(request)
757 def __init__(self, hs: "HomeServer"):
758 super().__init__()
759 self.auth = hs.get_auth()
760 self.clock = hs.get_clock()
761 self.groups_handler = hs.get_groups_local_handler()
762
763 @_validate_group_id
764 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
765 requester = await self.auth.get_user_by_req(request)
766 requester_user_id = requester.user.to_string()
767
768 content = parse_json_object_from_request(request)
769 assert isinstance(
770 self.groups_handler, GroupsLocalHandler
771 ), "Workers cannot join a user to a group."
636772 result = await self.groups_handler.join_group(
637773 group_id, requester_user_id, content
638774 )
641777
642778
643779 class GroupSelfAcceptInviteServlet(RestServlet):
644 """Accept a group invite
645 """
780 """Accept a group invite"""
646781
647782 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
648783
649 def __init__(self, hs):
650 super().__init__()
651 self.auth = hs.get_auth()
652 self.clock = hs.get_clock()
653 self.groups_handler = hs.get_groups_local_handler()
654
655 @_validate_group_id
656 async def on_PUT(self, request, group_id):
657 requester = await self.auth.get_user_by_req(request)
658 requester_user_id = requester.user.to_string()
659
660 content = parse_json_object_from_request(request)
784 def __init__(self, hs: "HomeServer"):
785 super().__init__()
786 self.auth = hs.get_auth()
787 self.clock = hs.get_clock()
788 self.groups_handler = hs.get_groups_local_handler()
789
790 @_validate_group_id
791 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
792 requester = await self.auth.get_user_by_req(request)
793 requester_user_id = requester.user.to_string()
794
795 content = parse_json_object_from_request(request)
796 assert isinstance(
797 self.groups_handler, GroupsLocalHandler
798 ), "Workers cannot accept an invite to a group."
661799 result = await self.groups_handler.accept_invite(
662800 group_id, requester_user_id, content
663801 )
666804
667805
668806 class GroupSelfUpdatePublicityServlet(RestServlet):
669 """Update whether we publicise a users membership of a group
670 """
807 """Update whether we publicise a user's membership of a group"""
671808
672809 PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
673810
674 def __init__(self, hs):
811 def __init__(self, hs: "HomeServer"):
675812 super().__init__()
676813 self.auth = hs.get_auth()
677814 self.clock = hs.get_clock()
678815 self.store = hs.get_datastore()
679816
680817 @_validate_group_id
681 async def on_PUT(self, request, group_id):
818 async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
682819 requester = await self.auth.get_user_by_req(request)
683820 requester_user_id = requester.user.to_string()
684821
690827
691828
692829 class PublicisedGroupsForUserServlet(RestServlet):
693 """Get the list of groups a user is advertising
694 """
830 """Get the list of groups a user is advertising"""
695831
696832 PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
697833
698 def __init__(self, hs):
834 def __init__(self, hs: "HomeServer"):
699835 super().__init__()
700836 self.auth = hs.get_auth()
701837 self.clock = hs.get_clock()
702838 self.store = hs.get_datastore()
703839 self.groups_handler = hs.get_groups_local_handler()
704840
705 async def on_GET(self, request, user_id):
841 async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
706842 await self.auth.get_user_by_req(request, allow_guest=True)
707843
708844 result = await self.groups_handler.get_publicised_groups_for_user(user_id)
711847
712848
713849 class PublicisedGroupsForUsersServlet(RestServlet):
714 """Get the list of groups a user is advertising
715 """
850 """Get the list of groups a user is advertising"""
716851
717852 PATTERNS = client_patterns("/publicised_groups$")
718853
719 def __init__(self, hs):
854 def __init__(self, hs: "HomeServer"):
720855 super().__init__()
721856 self.auth = hs.get_auth()
722857 self.clock = hs.get_clock()
723858 self.store = hs.get_datastore()
724859 self.groups_handler = hs.get_groups_local_handler()
725860
726 async def on_POST(self, request):
861 async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
727862 await self.auth.get_user_by_req(request, allow_guest=True)
728863
729864 content = parse_json_object_from_request(request)
735870
736871
737872 class GroupsForUserServlet(RestServlet):
738 """Get all groups the logged in user is joined to
739 """
873 """Get all groups the logged in user is joined to"""
740874
741875 PATTERNS = client_patterns("/joined_groups$")
742876
743 def __init__(self, hs):
744 super().__init__()
745 self.auth = hs.get_auth()
746 self.clock = hs.get_clock()
747 self.groups_handler = hs.get_groups_local_handler()
748
749 async def on_GET(self, request):
877 def __init__(self, hs: "HomeServer"):
878 super().__init__()
879 self.auth = hs.get_auth()
880 self.clock = hs.get_clock()
881 self.groups_handler = hs.get_groups_local_handler()
882
883 async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
750884 requester = await self.auth.get_user_by_req(request, allow_guest=True)
751885 requester_user_id = requester.user.to_string()
752886
755889 return 200, result
756890
757891
758 def register_servlets(hs, http_server):
892 def register_servlets(hs: "HomeServer", http_server):
759893 GroupServlet(hs).register(http_server)
760894 GroupSummaryServlet(hs).register(http_server)
761895 GroupInvitedUsersServlet(hs).register(http_server)
270270 body = parse_json_object_from_request(request)
271271
272272 await self.auth_handler.validate_user_via_ui_auth(
273 requester, request, body, "add a device signing key to your account",
273 requester,
274 request,
275 body,
276 "add a device signing key to your account",
274277 )
275278
276279 result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
192192 body, ["client_secret", "country", "phone_number", "send_attempt"]
193193 )
194194 client_secret = body["client_secret"]
195 assert_valid_client_secret(client_secret)
195196 country = body["country"]
196197 phone_number = body["phone_number"]
197198 send_attempt = body["send_attempt"]
292293
293294 sid = parse_string(request, "sid", required=True)
294295 client_secret = parse_string(request, "client_secret", required=True)
296 assert_valid_client_secret(client_secret)
295297 token = parse_string(request, "token", required=True)
296298
297299 # Attempt to validate a 3PID session
519521 # not this will raise a user-interactive auth error.
520522 try:
521523 auth_result, params, session_id = await self.auth_handler.check_ui_auth(
522 self._registration_flows, request, body, "register a new account",
524 self._registration_flows,
525 request,
526 body,
527 "register a new account",
523528 )
524529 except InteractiveAuthIncompleteError as e:
525530 # The user needs to provide more steps to complete auth.
662667 username, as_token
663668 )
664669 return await self._create_registration_details(
665 user_id, body, is_appservice_ghost=True,
670 user_id,
671 body,
672 is_appservice_ghost=True,
666673 )
667674
668675 async def _create_registration_details(
243243 requester = await self.auth.get_user_by_req(request, allow_guest=True)
244244
245245 await self.auth.check_user_in_room_or_world_readable(
246 room_id, requester.user.to_string(), allow_departed_users=True,
246 room_id,
247 requester.user.to_string(),
248 allow_departed_users=True,
247249 )
248250
249251 # This checks that a) the event exists and b) the user is allowed to
321323 requester = await self.auth.get_user_by_req(request, allow_guest=True)
322324
323325 await self.auth.check_user_in_room_or_world_readable(
324 room_id, requester.user.to_string(), allow_departed_users=True,
326 room_id,
327 requester.user.to_string(),
328 allow_departed_users=True,
325329 )
326330
327331 # This checks that a) the event exists and b) the user is allowed to
2929
3030
3131 class RoomUpgradeRestServlet(RestServlet):
32 """Handler for room uprade requests.
32 """Handler for room upgrade requests.
3333
3434 Handles requests of the form:
3535
136136 # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
137137 # is (essentially) a single US-ASCII word, and a `quoted-string` is a
138138 # US-ASCII string surrounded by double-quotes, using backslash as an
139 # escape charater. Note that %-encoding is *not* permitted.
139 # escape character. Note that %-encoding is *not* permitted.
140140 #
141141 # `filename*` is defined to be an `ext-value`, which is defined in
142142 # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
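The token / quoted-string / ext-value distinction spelled out in the comment above can be sketched with a hypothetical helper (illustrative only, not the function this file actually uses) that picks the right `Content-Disposition` form for a given upload name:

```python
import re
from urllib.parse import quote

# RFC 7230 "token" characters: if the name is a plain token, no quoting needed.
_ASCII_TOKEN = re.compile(r"^[A-Za-z0-9!#$%&'*+.^_`|~-]+$")


def content_disposition(upload_name: str) -> str:
    """Build a Content-Disposition value per RFC 6266 / RFC 5987 (sketch)."""
    if _ASCII_TOKEN.match(upload_name):
        # A bare token can be emitted as-is.
        return "inline; filename=%s" % upload_name
    try:
        upload_name.encode("ascii")
        # ASCII with specials: use a quoted-string, escaping backslash
        # and double-quote. %-encoding is *not* permitted here.
        escaped = upload_name.replace("\\", "\\\\").replace('"', '\\"')
        return 'inline; filename="%s"' % escaped
    except UnicodeEncodeError:
        # Non-ASCII: use the RFC 5987 ext-value form, where %-encoding
        # of UTF-8 bytes is allowed.
        return "inline; filename*=utf-8''%s" % quote(upload_name, safe="")


print(content_disposition("cat.png"))      # -> inline; filename=cat.png
print(content_disposition("my file.png"))  # -> inline; filename="my file.png"
print(content_disposition("émoji.png"))    # -> inline; filename*=utf-8''%C3%A9moji.png
```

The helper name and the `inline` disposition are assumptions for the sketch; the point is the three-way split the comment describes.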
5050 b" object-src 'self';",
5151 )
5252 request.setHeader(
53 b"Referrer-Policy", b"no-referrer",
53 b"Referrer-Policy",
54 b"no-referrer",
5455 )
5556 server_name, media_id, name = parse_media_id(request)
5657 if server_name == self.server_name:
183183 async def get_local_media(
184184 self, request: Request, media_id: str, name: Optional[str]
185185 ) -> None:
186 """Responds to reqests for local media, if exists, or returns 404.
186 """Responds to requests for local media, if exists, or returns 404.
187187
188188 Args:
189189 request: The incoming request.
305305 media_info = await self.store.get_cached_remote_media(server_name, media_id)
306306
307307 # file_id is the ID we use to track the file locally. If we've already
308 # seen the file then reuse the existing ID, otherwise genereate a new
308 # seen the file then reuse the existing ID, otherwise generate a new
309309 # one.
310310
311311 # If we have an entry in the DB, try and look for it
324324 # Failed to find the file anywhere, let's download it.
325325
326326 try:
327 media_info = await self._download_remote_file(server_name, media_id,)
327 media_info = await self._download_remote_file(
328 server_name,
329 media_id,
330 )
328331 except SynapseError:
329332 raise
330333 except Exception as e:
350353 responder = await self.media_storage.fetch_media(file_info)
351354 return responder, media_info
352355
353 async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
356 async def _download_remote_file(
357 self,
358 server_name: str,
359 media_id: str,
360 ) -> dict:
354361 """Attempt to download the remote file from the given server name,
355362 using the given file_id as the local id.
356363
772779 )
773780 except Exception as e:
774781 thumbnail_exists = await self.store.get_remote_media_thumbnail(
775 server_name, media_id, t_width, t_height, t_type,
782 server_name,
783 media_id,
784 t_width,
785 t_height,
786 t_type,
776787 )
777788 if not thumbnail_exists:
778789 raise e
831842 return await self._remove_local_media_from_disk([media_id])
832843
833844 async def delete_old_local_media(
834 self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True,
845 self,
846 before_ts: int,
847 size_gt: int = 0,
848 keep_profiles: bool = True,
835849 ) -> Tuple[List[str], int]:
836850 """
837851 Delete local or remote media from this server by size and timestamp. Removes
848862 A tuple of (list of deleted media IDs, total deleted media IDs).
849863 """
850864 old_media = await self.store.get_local_media_before(
851 before_ts, size_gt, keep_profiles,
865 before_ts,
866 size_gt,
867 keep_profiles,
852868 )
853869 return await self._remove_local_media_from_disk(old_media)
854870
926942
927943 <thumbnail>
928944
929 The thumbnail methods are "crop" and "scale". "scale" trys to return an
945 The thumbnail methods are "crop" and "scale". "scale" tries to return an
930946 image where either the width or the height is smaller than the requested
931947 size. The client should then scale and letterbox the image if it needs to
932 fit within a given rectangle. "crop" trys to return an image where the
948 fit within a given rectangle. "crop" tries to return an image where the
933949 width and height are close to the requested size and the aspect matches
934950 the requested size. The client should scale the image if it needs to fit
935951 within a given rectangle.
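The "scale" and "crop" behaviours described above can be sketched as two small dimension helpers (hypothetical names, assuming integer pixel sizes; not the resource's actual code):

```python
def scale_dimensions(width: int, height: int, max_width: int, max_height: int):
    """'scale': fit inside the requested box, preserving aspect ratio,
    so either the width or the height is within the requested size."""
    if max_width * height < max_height * width:
        # Width is the limiting dimension.
        return max_width, max(1, (height * max_width) // width)
    return max(1, (width * max_height) // height), max_height


def cover_dimensions(width: int, height: int, max_width: int, max_height: int):
    """'crop': scale so the image fully covers the requested box; the
    caller then crops the overflow so width, height and aspect match."""
    if max_width * height > max_height * width:
        return max_width, max(1, (height * max_width) // width)
    return max(1, (width * max_height) // height), max_height


# A 1000x500 image thumbnailed into a 100x100 box:
print(scale_dimensions(1000, 500, 100, 100))  # -> (100, 50): fits inside the box
print(cover_dimensions(1000, 500, 100, 100))  # -> (200, 100): then crop to 100x100
```

Either way the client may still need to scale or letterbox the result, as the docstring notes.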
1515 import logging
1616 import os
1717 import shutil
18 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
18 from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
19
20 import attr
1921
2022 from twisted.internet.defer import Deferred
2123 from twisted.internet.interfaces import IConsumer
2224 from twisted.protocols.basic import FileSender
2325
26 from synapse.api.errors import NotFoundError
2427 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
28 from synapse.util import Clock
2529 from synapse.util.file_consumer import BackgroundFileConsumer
2630
2731 from ._base import FileInfo, Responder
5761 self.local_media_directory = local_media_directory
5862 self.filepaths = filepaths
5963 self.storage_providers = storage_providers
64 self.spam_checker = hs.get_spam_checker()
65 self.clock = hs.get_clock()
6066
6167 async def store_file(self, source: IO, file_info: FileInfo) -> str:
6268 """Write `source` to the on disk media store, and also any other
7884 return fname
7985
8086 async def write_to_file(self, source: IO, output: IO):
81 """Asynchronously write the `source` to `output`.
82 """
87 """Asynchronously write the `source` to `output`."""
8388 await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
8489
8590 @contextlib.contextmanager
126131 f.flush()
127132 f.close()
128133
134 spam = await self.spam_checker.check_media_file_for_spam(
135 ReadableFileWrapper(self.clock, fname), file_info
136 )
137 if spam:
138 logger.info("Blocking media due to spam checker")
139 # Note that we'll delete the stored media, due to the
140 # try/except below. The media also won't be stored in
141 # the DB.
142 raise SpamMediaException()
143
129144 for provider in self.storage_providers:
130145 await provider.store_file(path, file_info)
131146
132147 finished_called[0] = True
133148
134149 yield f, fname, finish
135 except Exception:
150 except Exception as e:
136151 try:
137152 os.remove(fname)
138153 except Exception:
139154 pass
140 raise
155
156 raise e from None
141157
142158 if not finished_called:
143159 raise Exception("Finished callback not called")
301317
302318 def __exit__(self, exc_type, exc_val, exc_tb):
303319 self.open_file.close()
320
321
322 class SpamMediaException(NotFoundError):
323 """The media was blocked by a spam checker, so we simply 404 the request (in
324 the same way as if it was quarantined).
325 """
326
327
328 @attr.s(slots=True)
329 class ReadableFileWrapper:
330 """Wrapper that allows reading a file in chunks, yielding to the reactor,
331 and writing to a callback.
332
333 This is a simplified `FileSender` that takes an IO object rather than an
334 `IConsumer`.
335 """
336
337 CHUNK_SIZE = 2 ** 14
338
339 clock = attr.ib(type=Clock)
340 path = attr.ib(type=str)
341
342 async def write_chunks_to(self, callback: Callable[[bytes], None]):
343 """Reads the file in chunks and calls the callback with each chunk."""
344
345 with open(self.path, "rb") as file:
346 while True:
347 chunk = file.read(self.CHUNK_SIZE)
348 if not chunk:
349 break
350
351 callback(chunk)
352
353 # We yield to the reactor by sleeping for 0 seconds.
354 await self.clock.sleep(0)
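The chunk-and-yield pattern in `write_chunks_to` above can be reproduced outside Twisted; this asyncio sketch (names illustrative) reads fixed-size chunks, feeds each to a callback, and sleeps for 0 seconds between chunks so other tasks can run:

```python
import asyncio
import hashlib
import os
import tempfile

CHUNK_SIZE = 2 ** 14  # 16 KiB, matching the wrapper above


async def write_chunks_to(path: str, callback) -> None:
    """Read `path` in CHUNK_SIZE pieces, passing each chunk to `callback`."""
    with open(path, "rb") as f:
        while True:
            chunk = f.read(CHUNK_SIZE)
            if not chunk:
                break
            callback(chunk)
            # Sleeping for 0 seconds yields to the event loop, as the
            # wrapper does with `clock.sleep(0)` on the reactor.
            await asyncio.sleep(0)


async def hash_file(path: str) -> str:
    # Hash the file chunk-by-chunk, as a media spam checker might.
    digest = hashlib.sha256()
    await write_chunks_to(path, digest.update)
    return digest.hexdigest()


fd, path = tempfile.mkstemp()
os.write(fd, b"x" * 40000)  # spans three chunks
os.close(fd)
print(asyncio.run(hash_file(path)))
os.unlink(path)
```

The cooperative yield is the point: a large media file never blocks the event loop for its whole duration.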
5757
5858 logger = logging.getLogger(__name__)
5959
60 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
60 _charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
61 _xml_encoding_match = re.compile(
62 br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
63 )
6164 _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
6265
6366 OG_TAG_NAME_MAXLEN = 50
299302 with open(media_info["filename"], "rb") as file:
300303 body = file.read()
301304
302 encoding = None
303
304 # Let's try and figure out if it has an encoding set in a meta tag.
305 # Limit it to the first 1kb, since it ought to be in the meta tags
306 # at the top.
307 match = _charset_match.search(body[:1000])
308
309 # If we find a match, it should take precedence over the
310 # Content-Type header, so set it here.
311 if match:
312 encoding = match.group(1).decode("ascii")
313
314 # If we don't find a match, we'll look at the HTTP Content-Type, and
315 # if that doesn't exist, we'll fall back to UTF-8.
316 if not encoding:
317 content_match = _content_type_match.match(media_info["media_type"])
318 encoding = content_match.group(1) if content_match else "utf-8"
319
305 encoding = get_html_media_encoding(body, media_info["media_type"])
320306 og = decode_and_calc_og(body, media_info["uri"], encoding)
321307
322308 # pre-cache the image for posterity
593579 )
594580
595581 async def _expire_url_cache_data(self) -> None:
596 """Clean up expired url cache content, media and thumbnails.
597 """
582 """Clean up expired url cache content, media and thumbnails."""
598583 # TODO: Delete from backup media store
599584
600585 assert self._worker_run_media_background_jobs
688673 logger.debug("No media removed from url cache")
689674
690675
676 def get_html_media_encoding(body: bytes, content_type: str) -> str:
677 """
678 Get the encoding of the body based on the (presumably) HTML body or media_type.
679
680 The precedence used for finding a character encoding is:
681
682 1. meta tag with a charset declared.
683 2. The XML document's character encoding attribute.
684 3. The Content-Type header.
685 4. Fallback to UTF-8.
686
687 Args:
688 body: The HTML document, as bytes.
689 content_type: The Content-Type header.
690
691 Returns:
692 The character encoding of the body, as a string.
693 """
694 # Limit searches to the first 1kb, since it ought to be at the top.
695 body_start = body[:1024]
696
697 # Let's try and figure out if it has an encoding set in a meta tag.
698 match = _charset_match.search(body_start)
699 if match:
700 return match.group(1).decode("ascii")
701
702 # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
703
704 # If we didn't find a match, see if it is an XML document with an encoding.
705 match = _xml_encoding_match.match(body_start)
706 if match:
707 return match.group(1).decode("ascii")
708
709 # If we don't find a match, we'll look at the HTTP Content-Type, and
710 # if that doesn't exist, we'll fall back to UTF-8.
711 content_match = _content_type_match.match(content_type)
712 if content_match:
713 return content_match.group(1)
714
715 return "utf-8"
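The precedence above (meta charset, then XML declaration, then Content-Type header, then UTF-8) can be exercised standalone; the regexes below are copied from this file, and the function body mirrors `get_html_media_encoding`:

```python
import re

_charset_match = re.compile(rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
_xml_encoding_match = re.compile(
    rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)


def get_html_media_encoding(body: bytes, content_type: str) -> str:
    # Limit searches to the first 1kb, since it ought to be at the top.
    body_start = body[:1024]

    match = _charset_match.search(body_start)
    if match:
        return match.group(1).decode("ascii")

    match = _xml_encoding_match.match(body_start)
    if match:
        return match.group(1).decode("ascii")

    content_match = _content_type_match.match(content_type)
    if content_match:
        return content_match.group(1)

    return "utf-8"


# 1. A <meta charset> wins over the Content-Type header:
print(get_html_media_encoding(b'<html><meta charset="shift-jis">', "text/html; charset=utf-16"))  # -> shift-jis
# 2. An XML encoding declaration is used when no meta tag matches:
print(get_html_media_encoding(b'<?xml version="1.0" encoding="iso-8859-1"?>', "text/html"))  # -> iso-8859-1
# 3. Otherwise the Content-Type charset, else utf-8:
print(get_html_media_encoding(b"<html>", "text/html; charset=ascii"))  # -> ascii
```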
716
717
691718 def decode_and_calc_og(
692719 body: bytes, media_uri: str, request_encoding: Optional[str] = None
693720 ) -> Dict[str, Optional[str]]:
724751 def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
725752 # Attempt to parse the body. If this fails, log and return no metadata.
726753 tree = etree.fromstring(body_attempt, parser)
754
755 # The data was successfully parsed, but no tree was found.
756 if tree is None:
757 return {}
758
727759 return _calc_og(tree, media_uri)
728760
729761 # Attempt to parse the body. If this fails, log and return no metadata.
2121 from synapse.api.errors import Codes, SynapseError
2222 from synapse.http.server import DirectServeJsonResource, respond_with_json
2323 from synapse.http.servlet import parse_string
24 from synapse.rest.media.v1.media_storage import SpamMediaException
2425
2526 if TYPE_CHECKING:
2627 from synapse.app.homeserver import HomeServer
8586 # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
8687 # TODO(markjh): parse content-disposition
8788
88 content_uri = await self.media_repo.create_content(
89 media_type, upload_name, request.content, content_length, requester.user
90 )
89 try:
90 content_uri = await self.media_repo.create_content(
91 media_type, upload_name, request.content, content_length, requester.user
92 )
93 except SpamMediaException:
94 # For media uploads we want to respond with a 400, instead of
95 # the default 404, as a 404 would just be confusing.
96 raise SynapseError(400, "Bad content")
9197
9298 logger.info("Uploaded content with URI %r", content_uri)
9399
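The try/except added above translates an internal spam rejection into an HTTP 400. A minimal self-contained sketch of that pattern follows; the exception bodies and the `create_content` stand-in are illustrative, not Synapse's real implementations:

```python
import asyncio


class SpamMediaException(Exception):
    """Raised when a spam checker rejects an upload."""


class SynapseError(Exception):
    """Carries an HTTP status code alongside the error message."""

    def __init__(self, code: int, msg: str):
        super().__init__(msg)
        self.code = code


async def create_content(media_type: str, content: bytes) -> str:
    # Hypothetical stand-in: reject anything a checker flags as spam.
    if b"spam" in content:
        raise SpamMediaException()
    return "mxc://example.org/abc123"


async def handle_upload(media_type: str, content: bytes) -> str:
    try:
        return await create_content(media_type, content)
    except SpamMediaException:
        # Respond 400 instead of the default 404, which would be confusing
        # for a resource the client has only just tried to create.
        raise SynapseError(400, "Bad content")
```

The point of the translation is client-facing semantics: a 404 suggests the endpoint does not exist, while a 400 correctly signals that the request itself was rejected.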
5757 resources["/_synapse/client/saml2"] = res
5858
5959 # This is also mounted under '/_matrix' for backwards-compatibility.
60 # To be removed in Synapse v1.32.0.
6061 resources["/_matrix/saml2"] = res
6162
6263 return resources
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14
1415 import logging
16 from typing import TYPE_CHECKING
1517
1618 from synapse.http.server import DirectServeHtmlResource
19
20 if TYPE_CHECKING:
21 from synapse.server import HomeServer
1722
1823 logger = logging.getLogger(__name__)
1924
2126 class OIDCCallbackResource(DirectServeHtmlResource):
2227 isLeaf = 1
2328
24 def __init__(self, hs):
29 def __init__(self, hs: "HomeServer"):
2530 super().__init__()
2631 self._oidc_handler = hs.get_oidc_handler()
2732
2833 async def _async_render_GET(self, request):
2934 await self._oidc_handler.handle_oidc_callback(request)
35
36 async def _async_render_POST(self, request):
37 # the auth response can be returned via an x-www-form-urlencoded form instead
38 # of GET params, as per
39 # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
40 await self._oidc_handler.handle_oidc_callback(request)
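Both handlers above can feed the same callback because the auth response carries the same parameters whether it arrives as GET query params or as a form_post POST body. A sketch of that equivalence; the helper name and framing are illustrative, not Synapse's API:

```python
from urllib.parse import parse_qs, urlparse


def extract_auth_response(method: str, uri: str, body: bytes = b"") -> dict:
    """Pull OIDC auth-response params from a GET query string or from an
    x-www-form-urlencoded POST body (the form_post response mode)."""
    raw = urlparse(uri).query if method == "GET" else body.decode("utf-8")
    # parse_qs returns lists of values; auth responses carry single values.
    return {k: v[0] for k, v in parse_qs(raw).items()}
```

Either transport yields the same `{"code": ..., "state": ...}` mapping, which is why a single `handle_oidc_callback` can serve both render methods.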
2424 import functools
2525 import logging
2626 import os
27 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
27 from typing import (
28 TYPE_CHECKING,
29 Any,
30 Callable,
31 Dict,
32 List,
33 Optional,
34 TypeVar,
35 Union,
36 cast,
37 )
2838
2939 import twisted.internet.base
3040 import twisted.internet.tcp
587597 return UserDirectoryHandler(self)
588598
589599 @cache_in_self
590 def get_groups_local_handler(self):
600 def get_groups_local_handler(
601 self,
602 ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
591603 if self.config.worker_app:
592604 return GroupsLocalWorkerHandler(self)
593605 else:
2727
2828
2929 class ResourceLimitsServerNotices:
30 """ Keeps track of whether the server has reached it's resource limit and
30 """Keeps track of whether the server has reached it's resource limit and
3131 ensures that the client is kept up to date.
3232 """
3333
397397 async def resolve_state_groups_for_events(
398398 self, room_id: str, event_ids: Iterable[str]
399399 ) -> _StateCacheEntry:
400 """ Given a list of event_ids this method fetches the state at each
400 """Given a list of event_ids this method fetches the state at each
401401 event, resolves conflicts between them and returns them.
402402
403403 Args:
569569 return cache
570570
571571 logger.info(
572 "Resolving state for %s with groups %s", room_id, list(group_names),
572 "Resolving state for %s with groups %s",
573 room_id,
574 list(group_names),
573575 )
574576
575577 state_groups_histogram.observe(len(state_groups_ids))
614616 event_map:
615617 a dict from event_id to event, for any events that we happen to
616618 have in flight (eg, those currently being persisted). This will be
617 used as a starting point fof finding the state we need; any missing
619 used as a starting point for finding the state we need; any missing
618620 events will be requested via state_map_factory.
619621
620622 If None, all events will be fetched via state_res_store.
655657 return
656658
657659 self._report_biggest(
658 lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter,
660 lambda i: i.cpu_time,
661 "CPU time",
662 _biggest_room_by_cpu_counter,
659663 )
660664
661665 self._report_biggest(
662 lambda i: i.db_time, "DB time", _biggest_room_by_db_counter,
666 lambda i: i.db_time,
667 "DB time",
668 _biggest_room_by_db_counter,
663669 )
664670
665671 self._state_res_metrics.clear()
9494 if event.room_id != room_id:
9595 raise Exception(
9696 "Attempting to state-resolve for room %s with event %s which is in %s"
97 % (room_id, event.event_id, event.room_id,)
97 % (
98 room_id,
99 event.event_id,
100 event.room_id,
101 )
98102 )
99103
100104 # get the ids of the auth events which allow us to authenticate the
118122 if event.room_id != room_id:
119123 raise Exception(
120124 "Attempting to state-resolve for room %s with event %s which is in %s"
121 % (room_id, event.event_id, event.room_id,)
125 % (
126 room_id,
127 event.event_id,
128 event.room_id,
129 )
122130 )
123131
124132 state_map.update(state_map_new)
242250 def _resolve_state_events(
243251 conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
244252 ) -> StateMap[EventBase]:
245 """ This is where we actually decide which of the conflicted state to
253 """This is where we actually decide which of the conflicted state to
246254 use.
247255
248256 We resolve conflicts in the following order:
117117 if event.room_id != room_id:
118118 raise Exception(
119119 "Attempting to state-resolve for room %s with event %s which is in %s"
120 % (room_id, event.event_id, event.room_id,)
120 % (
121 room_id,
122 event.event_id,
123 event.room_id,
124 )
121125 )
122126
123127 full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
4242
4343
4444 class Storage:
45 """The high level interfaces for talking to various storage layers.
46 """
45 """The high level interfaces for talking to various storage layers."""
4746
4847 def __init__(self, hs: "HomeServer", stores: Databases):
4948 # We include the main data store here mainly so that we don't have to
7676
7777
7878 class BackgroundUpdater:
79 """ Background updates are updates to the database that run in the
79 """Background updates are updates to the database that run in the
8080 background. Each update processes a batch of data at once. We attempt to
8181 limit the impact of each update by monitoring how long each batch takes to
8282 process and autotuning the batch size.
157157 return False
158158
159159 async def has_completed_background_update(self, update_name: str) -> bool:
160 """Check if the given background update has finished running.
161 """
160 """Check if the given background update has finished running."""
162161 if self._all_done:
163162 return True
164163
197196
198197 if not self._current_background_update:
199198 all_pending_updates = await self.db_pool.runInteraction(
200 "background_updates", get_background_updates_txn,
199 "background_updates",
200 get_background_updates_txn,
201201 )
202202 if not all_pending_updates:
203203 # no work left to do
8484 def make_pool(
8585 reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
8686 ) -> adbapi.ConnectionPool:
87 """Get the connection pool for the database.
88 """
87 """Get the connection pool for the database."""
8988
9089 # By default enable `cp_reconnect`. We need to fiddle with db_args in case
9190 # someone has explicitly set `cp_reconnect`.
157156 def commit(self) -> None:
158157 self.conn.commit()
159158
160 def rollback(self, *args, **kwargs) -> None:
161 self.conn.rollback(*args, **kwargs)
159 def rollback(self) -> None:
160 self.conn.rollback()
162161
163162 def __enter__(self) -> "Connection":
164163 self.conn.__enter__()
243242 assert self.exception_callbacks is not None
244243 self.exception_callbacks.append((callback, args, kwargs))
245244
245 def fetchone(self) -> Optional[Tuple]:
246 return self.txn.fetchone()
247
248 def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
249 return self.txn.fetchmany(size=size)
250
246251 def fetchall(self) -> List[Tuple]:
247252 return self.txn.fetchall()
248
249 def fetchone(self) -> Tuple:
250 return self.txn.fetchone()
251253
252254 def __iter__(self) -> Iterator[Tuple]:
253255 return self.txn.__iter__()
428430 )
429431
430432 def is_running(self) -> bool:
431 """Is the database pool currently running
432 """
433 """Is the database pool currently running"""
433434 return self._db_pool.running
434435
435436 async def _check_safe_to_upsert(self) -> None:
542543 # This can happen if the database disappears mid
543544 # transaction.
544545 transaction_logger.warning(
545 "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
546 "[TXN OPERROR] {%s} %s %d/%d",
547 name,
548 e,
549 i,
550 N,
546551 )
547552 if i < N:
548553 i += 1
563568 conn.rollback()
564569 except self.engine.module.Error as e1:
565570 transaction_logger.warning(
566 "[TXN EROLL] {%s} %s", name, e1,
571 "[TXN EROLL] {%s} %s",
572 name,
573 e1,
567574 )
568575 continue
569576 raise
753760 Returns:
754761 A list of dicts where the key is the column header.
755762 """
763 assert cursor.description is not None, "cursor.description was None"
756764 col_headers = [intern(str(column[0])) for column in cursor.description]
757765 results = [dict(zip(col_headers, row)) for row in cursor]
758766 return results
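The assertion added above guards `cursor.description`, which DB-API drivers leave as `None` until a row-returning statement has executed. A minimal sketch of the same row-to-dict conversion against sqlite3:

```python
import sqlite3


def cursor_to_dict(cursor):
    # description is None before any SELECT has run, hence the guard.
    assert cursor.description is not None, "cursor.description was None"
    col_headers = [str(column[0]) for column in cursor.description]
    return [dict(zip(col_headers, row)) for row in cursor]


conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE users (name TEXT, admin INTEGER)")
cur.executemany("INSERT INTO users VALUES (?, ?)", [("alice", 1), ("bob", 0)])
cur.execute("SELECT name, admin FROM users ORDER BY name")
rows = cursor_to_dict(cur)
```

Here `rows` comes out as a list of dicts keyed by column name, which is easier to consume than positional tuples.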
14011409
14021410 @staticmethod
14031411 def simple_select_onecol_txn(
1404 txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
1412 txn: LoggingTransaction,
1413 table: str,
1414 keyvalues: Dict[str, Any],
1415 retcol: str,
14051416 ) -> List[Any]:
14061417 sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
14071418
17111722 desc: description of the transaction, for logging and metrics
17121723 """
17131724 await self.runInteraction(
1714 desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
1725 desc,
1726 self.simple_delete_one_txn,
1727 table,
1728 keyvalues,
1729 db_autocommit=True,
17151730 )
17161731
17171732 @staticmethod
5555 database_config.databases,
5656 )
5757 prepare_database(
58 db_conn, engine, hs.config, databases=database_config.databases,
58 db_conn,
59 engine,
60 hs.config,
61 databases=database_config.databases,
5962 )
6063
6164 database = DatabasePool(hs, database_config, engine)
339339 count = txn.fetchone()[0]
340340
341341 sql = (
342 "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
342 "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url "
343343 + sql_base
344344 + " ORDER BY u.name LIMIT ? OFFSET ?"
345345 )
7272 return self.services_cache
7373
7474 def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
75 """Check if the user is one associated with an app service (exclusively)
76 """
75 """Check if the user is one associated with an app service (exclusively)"""
7776 if self.exclusive_user_regex:
7877 return bool(self.exclusive_user_regex.match(user_id))
7978 else:
279279 return batch_size
280280
281281 async def _devices_last_seen_update(self, progress, batch_size):
282 """Background update to insert last seen info into devices table
283 """
282 """Background update to insert last seen info into devices table"""
284283
285284 last_user_id = progress.get("last_user_id", "")
286285 last_device_id = progress.get("last_device_id", "")
362361
363362 @wrap_as_background_process("prune_old_user_ips")
364363 async def _prune_old_user_ips(self):
365 """Removes entries in user IPs older than the configured period.
366 """
364 """Removes entries in user IPs older than the configured period."""
367365
368366 if self.user_ips_max_age is None:
369367 # Nothing to do
564562 results = {}
565563
566564 for key in self._batch_row_update:
567 uid, access_token, ip, = key
565 (
566 uid,
567 access_token,
568 ip,
569 ) = key
568570 if uid == user_id:
569571 user_agent, _, last_seen = self._batch_row_update[key]
570572 results[(access_token, ip)] = (user_agent, last_seen)
449449 },
450450 )
451451
452 # Add the messages to the approriate local device inboxes so that
452 # Add the messages to the appropriate local device inboxes so that
453453 # they'll be sent to the devices when they next sync.
454454 self._add_messages_to_local_device_inbox_txn(
455455 txn, stream_id, local_messages_by_user_then_device
314314
315315 # make sure we go through the devices in stream order
316316 device_ids = sorted(
317 user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
317 user_devices.keys(),
318 key=lambda i: query_map[(user_id, i)][0],
318319 )
319320
320321 for device_id in device_ids:
365366 async def mark_as_sent_devices_by_remote(
366367 self, destination: str, stream_id: int
367368 ) -> None:
368 """Mark that updates have successfully been sent to the destination.
369 """
369 """Mark that updates have successfully been sent to the destination."""
370370 await self.db_pool.runInteraction(
371371 "mark_as_sent_devices_by_remote",
372372 self._mark_as_sent_devices_by_remote_txn,
680680 return results
681681
682682 async def get_user_ids_requiring_device_list_resync(
683 self, user_ids: Optional[Collection[str]] = None,
683 self,
684 user_ids: Optional[Collection[str]] = None,
684685 ) -> Set[str]:
685686 """Given a list of remote users return the list of users that we
686687 should resync the device lists for. If None is given instead of a list,
720721 )
721722
722723 async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
723 """Mark that we no longer track device lists for remote user.
724 """
724 """Mark that we no longer track device lists for remote user."""
725725
726726 def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
727727 self.db_pool.simple_delete_txn(
901901 logger.info("Pruned %d device list outbound pokes", count)
902902
903903 await self.db_pool.runInteraction(
904 "_prune_old_outbound_device_pokes", _prune_txn,
904 "_prune_old_outbound_device_pokes",
905 _prune_txn,
905906 )
906907
907908
942943
943944 # clear out duplicate device list outbound pokes
944945 self.db_pool.updates.register_background_update_handler(
945 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
946 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
947 self._remove_duplicate_outbound_pokes,
946948 )
947949
948950 # a pair of background updates that were added during the 1.14 release cycle,
10031005 row = None
10041006 for row in rows:
10051007 self.db_pool.simple_delete_txn(
1006 txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
1008 txn,
1009 "device_lists_outbound_pokes",
1010 {x: row[x] for x in KEY_COLS},
10071011 )
10081012
10091013 row["sent"] = False
10101014 self.db_pool.simple_insert_txn(
1011 txn, "device_lists_outbound_pokes", row,
1015 txn,
1016 "device_lists_outbound_pokes",
1017 row,
10121018 )
10131019
10141020 if row:
10151021 self.db_pool.updates._background_update_progress_txn(
1016 txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
1022 txn,
1023 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
1024 {"last_row": row},
10171025 )
10181026
10191027 return len(rows)
12851293 # we've done a full resync, so we remove the entry that says we need
12861294 # to resync
12871295 self.db_pool.simple_delete_txn(
1288 txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
1296 txn,
1297 table="device_lists_remote_resync",
1298 keyvalues={"user_id": user_id},
12891299 )
12901300
12911301 async def add_device_change_to_streams(
13351345 stream_ids: List[str],
13361346 ):
13371347 txn.call_after(
1338 self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
1348 self._device_list_stream_cache.entity_has_changed,
1349 user_id,
1350 stream_ids[-1],
13391351 )
13401352
13411353 min_stream_id = stream_ids[0]
8484 servers: Iterable[str],
8585 creator: Optional[str] = None,
8686 ) -> None:
87 """ Creates an association between a room alias and room_id/servers
87 """Creates an association between a room alias and room_id/servers
8888
8989 Args:
9090 room_alias: The alias to create.
159159 return room_id
160160
161161 async def update_aliases_for_room(
162 self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
162 self,
163 old_room_id: str,
164 new_room_id: str,
165 creator: Optional[str] = None,
163166 ) -> None:
164167 """Repoint all of the aliases for a given room, to a different room.
165168
360360 async def count_e2e_one_time_keys(
361361 self, user_id: str, device_id: str
362362 ) -> Dict[str, int]:
363 """ Count the number of one time keys the server has for a device
363 """Count the number of one time keys the server has for a device
364364 Returns:
365365 A mapping from algorithm to number of keys for that algorithm.
366366 """
493493 )
494494
495495 def _get_bare_e2e_cross_signing_keys_bulk_txn(
496 self, txn: Connection, user_ids: List[str],
496 self,
497 txn: Connection,
498 user_ids: List[str],
497499 ) -> Dict[str, Dict[str, dict]]:
498500 """Returns the cross-signing keys for a set of users. The output of this
499501 function should be passed to _get_e2e_cross_signing_signatures_txn if
555557 return result
556558
557559 def _get_e2e_cross_signing_signatures_txn(
558 self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
560 self,
561 txn: Connection,
562 keys: Dict[str, Dict[str, dict]],
563 from_user_id: str,
559564 ) -> Dict[str, Dict[str, dict]]:
560565 """Returns the cross-signing signatures made by a user on a set of keys.
561566
7070 return await self.get_events_as_list(event_ids)
7171
7272 async def get_auth_chain_ids(
73 self, event_ids: Collection[str], include_given: bool = False,
73 self,
74 event_ids: Collection[str],
75 include_given: bool = False,
7476 ) -> List[str]:
7577 """Get auth events for given event_ids. The events *must* be state events.
7678
272274 # origin chain.
273275 if origin_sequence_number <= chains.get(origin_chain_id, 0):
274276 chains[target_chain_id] = max(
275 target_sequence_number, chains.get(target_chain_id, 0),
277 target_sequence_number,
278 chains.get(target_chain_id, 0),
276279 )
277280
278281 seen_chains.add(target_chain_id)
370373 # and state sets {A} and {B} then walking the auth chains of A and B
371374 # would immediately show that C is reachable by both. However, if we
372375 # stopped at C then we'd only reach E via the auth chain of B and so E
373 # would errornously get included in the returned difference.
376 # would erroneously get included in the returned difference.
374377 #
375378 # The other thing that we do is limit the number of auth chains we walk
376379 # at once, due to practical limits (i.e. we can only query the database
496499
497500 a_ids = new_aids
498501
499 # Mark that the auth event is reachable by the approriate sets.
502 # Mark that the auth event is reachable by the appropriate sets.
500503 sets.intersection_update(event_to_missing_sets[event_id])
501504
502505 search.sort()
631634 )
632635
633636 async def get_min_depth(self, room_id: str) -> int:
634 """For the given room, get the minimum depth we have seen for it.
635 """
637 """For the given room, get the minimum depth we have seen for it."""
636638 return await self.db_pool.runInteraction(
637639 "get_min_depth", self._get_min_depth_interaction, room_id
638640 )
857859 )
858860
859861 await self.db_pool.runInteraction(
860 "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
862 "_delete_old_forward_extrem_cache",
863 _delete_old_forward_extrem_cache_txn,
861864 )
862865
863866
864867 class EventFederationStore(EventFederationWorkerStore):
865 """ Responsible for storing and serving up the various graphs associated
868 """Responsible for storing and serving up the various graphs associated
866869 with an event. Including the main event graph and the auth chains for an
867870 event.
868871
5353
5454
5555 def _deserialize_action(actions, is_highlight):
56 """Custom deserializer for actions. This allows us to "compress" common actions
57 """
56 """Custom deserializer for actions. This allows us to "compress" common actions"""
5857 if actions:
5958 return db_to_json(actions)
6059
9089
9190 @cached(num_args=3, tree=True, max_entries=5000)
9291 async def get_unread_event_push_actions_by_room_for_user(
93 self, room_id: str, user_id: str, last_read_event_id: Optional[str],
92 self,
93 room_id: str,
94 user_id: str,
95 last_read_event_id: Optional[str],
9496 ) -> Dict[str, int]:
9597 """Get the notification count, the highlight count and the unread message count
9698 for a given user in a given room after the given read receipt.
119121 )
120122
121123 def _get_unread_counts_by_receipt_txn(
122 self, txn, room_id, user_id, last_read_event_id,
124 self,
125 txn,
126 room_id,
127 user_id,
128 last_read_event_id,
123129 ):
124130 stream_ordering = None
125131
126132 if last_read_event_id is not None:
127133 stream_ordering = self.get_stream_id_for_event_txn(
128 txn, last_read_event_id, allow_none=True,
134 txn,
135 last_read_event_id,
136 allow_none=True,
129137 )
130138
131139 if stream_ordering is None:
398398 self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
399399
400400 def _persist_event_auth_chain_txn(
401 self, txn: LoggingTransaction, events: List[EventBase],
401 self,
402 txn: LoggingTransaction,
403 events: List[EventBase],
402404 ) -> None:
403405
404406 # We only care about state events, so skip this if there are no state events.
469471 event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
470472
471473 self._add_chain_cover_index(
472 txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
474 txn,
475 self.db_pool,
476 event_to_room_id,
477 event_to_types,
478 event_to_auth_chain,
473479 )
474480
475481 @classmethod
516522 # simple_select_many, but this case happens rarely and almost always
517523 # with a single row.)
518524 auth_events = db_pool.simple_select_onecol_txn(
519 txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
525 txn,
526 "event_auth",
527 keyvalues={"event_id": event_id},
528 retcol="auth_id",
520529 )
521530
522531 events_to_calc_chain_id_for.add(event_id)
549558 WHERE
550559 """
551560 clause, args = make_in_list_sql_clause(
552 txn.database_engine, "event_id", missing_auth_chains,
561 txn.database_engine,
562 "event_id",
563 missing_auth_chains,
553564 )
554565 txn.execute(sql + clause, args)
555566
703714 if chain_map[a_id][0] != chain_id
704715 }
705716 for start_auth_id, end_auth_id in itertools.permutations(
706 event_to_auth_chain.get(event_id, []), r=2,
717 event_to_auth_chain.get(event_id, []),
718 r=2,
707719 ):
708720 if chain_links.exists_path_from(
709721 chain_map[start_auth_id], chain_map[end_auth_id]
887899 txn: LoggingTransaction,
888900 events_and_contexts: List[Tuple[EventBase, EventContext]],
889901 ):
890 """Persist the mapping from transaction IDs to event IDs (if defined).
891 """
902 """Persist the mapping from transaction IDs to event IDs (if defined)."""
892903
893904 to_insert = []
894905 for event, _ in events_and_contexts:
908919
909920 if to_insert:
910921 self.db_pool.simple_insert_many_txn(
911 txn, table="event_txn_id", values=to_insert,
922 txn,
923 table="event_txn_id",
924 values=to_insert,
912925 )
913926
914927 def _update_current_state_txn(
940953 txn.execute(sql, (stream_id, self._instance_name, room_id))
941954
942955 self.db_pool.simple_delete_txn(
943 txn, table="current_state_events", keyvalues={"room_id": room_id},
956 txn,
957 table="current_state_events",
958 keyvalues={"room_id": room_id},
944959 )
945960 else:
946961 # We're still in the room, so we update the current state as normal.
10491064 # Figure out the changes of membership to invalidate the
10501065 # `get_rooms_for_user` cache.
10511066 # We find out which membership events we may have deleted
1052 # and which we have added, then we invlidate the caches for all
1067 # and which we have added, then we invalidate the caches for all
10531068 # those users.
10541069 members_changed = {
10551070 state_key
16071622 )
16081623
16091624 def _store_room_members_txn(self, txn, events, backfilled):
1610 """Store a room member in the database.
1611 """
1625 """Store a room member in the database."""
16121626
16131627 def str_or_none(val: Any) -> Optional[str]:
16141628 return val if isinstance(val, str) else None
20002014
20012015 @attr.s(slots=True)
20022016 class _LinkMap:
2003 """A helper type for tracking links between chains.
2004 """
2017 """A helper type for tracking links between chains."""
20052018
20062019 # Stores the set of links as nested maps: source chain ID -> target chain ID
20072020 # -> source sequence number -> target sequence number.
21072120 yield (src_chain, src_seq, target_chain, target_seq)
21082121
21092122 def exists_path_from(
2110 self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
2123 self,
2124 src_tuple: Tuple[int, int],
2125 target_tuple: Tuple[int, int],
21112126 ) -> bool:
21122127 """Checks if there is a path between the source chain ID/sequence and
21132128 target chain ID/sequence.
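A rough sketch of how such a link map can answer `exists_path_from`. This is an illustrative reimplementation under assumed semantics (within a chain, lower sequence numbers are ancestors; across chains, a recorded link bridges them; only direct links are checked, with no multi-hop composition), not Synapse's exact `_LinkMap`:

```python
class LinkMap:
    """Nested maps: source chain ID -> target chain ID -> src seq -> target seq."""

    def __init__(self):
        self.maps = {}

    def add_link(self, src, target):
        src_chain, src_seq = src
        target_chain, target_seq = target
        self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})[
            src_seq
        ] = target_seq

    def exists_path_from(self, src_tuple, target_tuple):
        src_chain, src_seq = src_tuple
        target_chain, target_seq = target_tuple
        if src_chain == target_chain:
            # Within one chain, earlier sequence numbers are ancestors.
            return target_seq <= src_seq
        links = self.maps.get(src_chain, {}).get(target_chain, {})
        # A link is usable if its source sits at or below our position and
        # its target sits at or above the sequence we are trying to reach.
        return any(
            l_src <= src_seq and target_seq <= l_tgt for l_src, l_tgt in links.items()
        )
```

The compact representation matters: a single cross-chain link at `(src_seq, target_seq)` implicitly covers every pair of positions on either side of it.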
3131
3232 @attr.s(slots=True, frozen=True)
3333 class _CalculateChainCover:
34 """Return value for _calculate_chain_cover_txn.
35 """
34 """Return value for _calculate_chain_cover_txn."""
3635
3736 # The last room_id/depth/stream processed.
3837 room_id = attr.ib(type=str)
126125 )
127126
128127 self.db_pool.updates.register_background_update_handler(
129 "rejected_events_metadata", self._rejected_events_metadata,
128 "rejected_events_metadata",
129 self._rejected_events_metadata,
130130 )
131131
132132 self.db_pool.updates.register_background_update_handler(
133 "chain_cover", self._chain_cover_index,
133 "chain_cover",
134 self._chain_cover_index,
134135 )
135136
136137 async def _background_reindex_fields_sender(self, progress, batch_size):
461462 return num_handled
462463
463464 async def _redactions_received_ts(self, progress, batch_size):
464 """Handles filling out the `received_ts` column in redactions.
465 """
465 """Handles filling out the `received_ts` column in redactions."""
466466 last_event_id = progress.get("last_event_id", "")
467467
468468 def _redactions_received_ts_txn(txn):
517517 return count
518518
519519 async def _event_fix_redactions_bytes(self, progress, batch_size):
520 """Undoes hex encoded censored redacted event JSON.
521 """
520 """Undoes hex encoded censored redacted event JSON."""
522521
523522 def _event_fix_redactions_bytes_txn(txn):
524523 # This update is quite fast due to new index.
641640 LIMIT ?
642641 """
643642
644 txn.execute(sql, (last_event_id, batch_size,))
643 txn.execute(
644 sql,
645 (
646 last_event_id,
647 batch_size,
648 ),
649 )
645650
646651 return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
647652
909914 # Annoyingly we need to gut wrench into the persist event store so that
910915 # we can reuse the function to calculate the chain cover for rooms.
911916 PersistEventsStore._add_chain_cover_index(
912 txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
917 txn,
918 self.db_pool,
919 event_to_room_id,
920 event_to_types,
921 event_to_auth_chain,
913922 )
914923
915924 return _CalculateChainCover(
7070 if txn.rowcount > 0:
7171 # Invalidate the cache
7272 self._invalidate_cache_and_stream(
73 txn, self.get_latest_event_ids_in_room, (room_id,),
73 txn,
74 self.get_latest_event_ids_in_room,
75 (room_id,),
7476 )
7577
7678 return txn.rowcount
9698 return self.db_pool.cursor_to_dict(txn)
9799
98100 return await self.db_pool.runInteraction(
99 "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
101 "get_forward_extremities_for_room",
102 get_forward_extremities_for_room_txn,
100103 )
119119 # SQLite).
120120 if hs.get_instance_name() in hs.config.worker.writers.events:
121121 self._stream_id_gen = StreamIdGenerator(
122 db_conn, "events", "stream_ordering",
122 db_conn,
123 "events",
124 "stream_ordering",
123125 )
124126 self._backfill_id_gen = StreamIdGenerator(
125127 db_conn,
139141 if hs.config.run_background_tasks:
140142 # We periodically clean out old transaction ID mappings
141143 self._clock.looping_call(
142 self._cleanup_old_transaction_ids, 5 * 60 * 1000,
144 self._cleanup_old_transaction_ids,
145 5 * 60 * 1000,
143146 )
144147
145148 self._get_event_cache = LruCache(
13241327 return rows, to_token, True
13251328
13261329 async def is_event_after(self, event_id1, event_id2):
1327 """Returns True if event_id1 is after event_id2 in the stream
1328 """
1330 """Returns True if event_id1 is after event_id2 in the stream"""
13291331 to_1, so_1 = await self.get_event_ordering(event_id1)
13301332 to_2, so_2 = await self.get_event_ordering(event_id2)
13311333 return (to_1, so_1) > (to_2, so_2)
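`is_event_after` leans on Python's lexicographic tuple comparison: the topological ordering is the primary key and the stream ordering breaks ties. A tiny illustration:

```python
def is_event_after(ordering1, ordering2):
    # Each ordering is a (topological_ordering, stream_ordering) pair;
    # tuples compare element by element, left to right.
    return ordering1 > ordering2
```

So an event with a higher topological ordering always sorts later, regardless of its stream ordering.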
14271429
14281430 @wrap_as_background_process("_cleanup_old_transaction_ids")
14291431 async def _cleanup_old_transaction_ids(self):
1430 """Cleans out transaction id mappings older than 24hrs.
1431 """
1432 """Cleans out transaction id mappings older than 24hrs."""
14321433
14331434 def _cleanup_old_transaction_ids_txn(txn):
14341435 sql = """
14391440 txn.execute(sql, (one_day_ago,))
14401441
14411442 return await self.db_pool.runInteraction(
1442 "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
1443 )
1443 "_cleanup_old_transaction_ids",
1444 _cleanup_old_transaction_ids_txn,
1445 )
1313 # See the License for the specific language governing permissions and
1414 # limitations under the License.
1515
16 from typing import Any, Dict, List, Optional, Tuple, Union
16 from typing import Any, Dict, List, Optional, Tuple
17
18 from typing_extensions import TypedDict
1719
1820 from synapse.api.errors import SynapseError
1921 from synapse.storage._base import SQLBaseStore, db_to_json
2426 # database to avoid the fun of null != null
2527 _DEFAULT_CATEGORY_ID = ""
2628 _DEFAULT_ROLE_ID = ""
29
30 # A room in a group.
31 _RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
2732
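A functional-syntax `TypedDict` like `_RoomInGroup` gives dict literals a statically checked shape while remaining a plain dict at runtime. A brief sketch; the `summarise` helper and sample data are illustrative, not part of Synapse:

```python
from typing import List, TypedDict  # Synapse imports TypedDict from typing_extensions

# Functional syntax, as in the diff above: keys mapped to their value types.
_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})


def summarise(rooms: List[_RoomInGroup]) -> List[str]:
    # At runtime these are ordinary dicts; the typing is purely static.
    return [room["room_id"] for room in rooms if room["is_public"]]


rooms: List[_RoomInGroup] = [
    {"room_id": "!a:example.org", "is_public": True},
    {"room_id": "!b:example.org", "is_public": False},
]
```

This is why the return annotation changed from `List[Dict[str, Union[str, bool]]]` to `List[_RoomInGroup]`: the TypedDict pins each key to its own type instead of a loose union.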
2833
2934 class GroupServerWorkerStore(SQLBaseStore):
7176
7277 async def get_rooms_in_group(
7378 self, group_id: str, include_private: bool = False
74 ) -> List[Dict[str, Union[str, bool]]]:
79 ) -> List[_RoomInGroup]:
7580 """Retrieve the rooms that belong to a given group. Does not return rooms that
7681 lack members.
7782
122127 )
123128
124129 async def get_rooms_for_summary_by_category(
125 self, group_id: str, include_private: bool = False,
130 self,
131 group_id: str,
132 include_private: bool = False,
126133 ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
127134 """Get the rooms and categories that should be included in a summary request
128135
367374 async def is_user_invited_to_local_group(
368375 self, group_id: str, user_id: str
369376 ) -> Optional[bool]:
370 """Has the group server invited a user?
371 """
377 """Has the group server invited a user?"""
372378 return await self.db_pool.simple_select_one_onecol(
373379 table="group_invites",
374380 keyvalues={"group_id": group_id, "user_id": user_id},
426432 )
427433
428434 async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
429 """Get all groups a user is publicising
430 """
435 """Get all groups a user is publicising"""
431436 return await self.db_pool.simple_select_onecol(
432437 table="local_group_membership",
433438 keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
436441 )
437442
438443 async def get_attestations_need_renewals(self, valid_until_ms):
439 """Get all attestations that need to be renewed until givent time
440 """
444 """Get all attestations that need to be renewed until givent time"""
441445
442446 def _get_attestations_need_renewals_txn(txn):
443447 sql = """
780784 profile: Optional[JsonDict],
781785 is_public: Optional[bool],
782786 ) -> None:
783 """Add/update room category for group
784 """
787 """Add/update room category for group"""
785788 insertion_values = {}
786789 update_values = {"category_id": category_id} # This cannot be empty
787790
817820 profile: Optional[JsonDict],
818821 is_public: Optional[bool],
819822 ) -> None:
820 """Add/remove user role
821 """
823 """Add/remove user role"""
822824 insertion_values = {}
823825 update_values = {"role_id": role_id} # This cannot be empty
824826
10111013 )
10121014
10131015 async def add_group_invite(self, group_id: str, user_id: str) -> None:
1014 """Record that the group server has invited a user
1015 """
1016 """Record that the group server has invited a user"""
10161017 await self.db_pool.simple_insert(
10171018 table="group_invites",
10181019 values={"group_id": group_id, "user_id": user_id},
11551156 async def update_group_publicity(
11561157 self, group_id: str, user_id: str, publicise: bool
11571158 ) -> None:
1158 """Update whether the user is publicising their membership of the group
1159 """
1159 """Update whether the user is publicising their membership of the group"""
11601160 await self.db_pool.simple_update_one(
11611161 table="local_group_membership",
11621162 keyvalues={"group_id": group_id, "user_id": user_id},
12991299 async def update_attestation_renewal(
13001300 self, group_id: str, user_id: str, attestation: dict
13011301 ) -> None:
1302 """Update an attestation that we have renewed
1303 """
1302 """Update an attestation that we have renewed"""
13041303 await self.db_pool.simple_update_one(
13051304 table="group_attestations_renewals",
13061305 keyvalues={"group_id": group_id, "user_id": user_id},
13111310 async def update_remote_attestion(
13121311 self, group_id: str, user_id: str, attestation: dict
13131312 ) -> None:
1314 """Update an attestation that a remote has renewed
1315 """
1313 """Update an attestation that a remote has renewed"""
13161314 await self.db_pool.simple_update_one(
13171315 table="group_attestations_remote",
13181316 keyvalues={"group_id": group_id, "user_id": user_id},
3232
3333
3434 class KeyStore(SQLBaseStore):
35 """Persistence for signature verification keys
36 """
35 """Persistence for signature verification keys"""
3736
3837 @cached()
3938 def _get_server_verify_key(self, server_name_and_key_id):
154153 (server_name, key_id, from_server) triplet if one already existed.
155154 Args:
156155 server_name: The name of the server.
157 key_id: The identifer of the key this JSON is for.
156 key_id: The identifier of the key this JSON is for.
158157 from_server: The server this JSON was fetched from.
159158 ts_now_ms: The time now in milliseconds.
160159 ts_valid_until_ms: The time when this json stops being valid.
181180 async def get_server_keys_json(
182181 self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
183182 ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
184 """Retrive the key json for a list of server_keys and key ids.
183 """Retrieve the key json for a list of server_keys and key ids.
185184 If no keys are found for a given server, key_id and source then
186185 that server, key_id, and source triplet entry will be an empty list.
187186 The JSON is returned as a byte array so that it can be efficiently
168168 )
169169
170170 async def get_local_media_before(
171 self, before_ts: int, size_gt: int, keep_profiles: bool,
171 self,
172 before_ts: int,
173 size_gt: int,
174 keep_profiles: bool,
172175 ) -> List[str]:
173176
174177 # to find files that have never been accessed (last_access_ts IS NULL)
453456 )
454457
455458 async def get_remote_media_thumbnail(
456 self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
459 self,
460 origin: str,
461 media_id: str,
462 t_width: int,
463 t_height: int,
464 t_type: str,
457465 ) -> Optional[Dict[str, Any]]:
458 """Fetch the thumbnail info of given width, height and type.
459 """
466 """Fetch the thumbnail info of given width, height and type."""
460467
461468 return await self.db_pool.simple_select_one(
462469 table="remote_media_cache_thumbnails",
110110 async def count_daily_sent_e2ee_messages(self):
111111 def _count_messages(txn):
112112 # This is good enough as if you have silly characters in your own
113 # hostname then thats your own fault.
113 # hostname then that's your own fault.
114114 like_clause = "%:" + self.hs.hostname
115115
116116 sql = """
166166 async def count_daily_sent_messages(self):
167167 def _count_messages(txn):
168168 # This is good enough as if you have silly characters in your own
169 # hostname then thats your own fault.
169 # hostname then that's your own fault.
170170 like_clause = "%:" + self.hs.hostname
171171
172172 sql = """
129129 raise NotImplementedError()
130130
131131 @cachedList(
132 cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
132 cached_method_name="_get_presence_for_user",
133 list_name="user_ids",
134 num_args=1,
133135 )
134136 async def get_presence_for_users(self, user_ids):
135137 rows = await self.db_pool.simple_select_many_batch(
117117 )
118118
119119 async def is_subscribed_remote_profile_for_user(self, user_id):
120 """Check whether we are interested in a remote user's profile.
121 """
120 """Check whether we are interested in a remote user's profile."""
122121 res = await self.db_pool.simple_select_one_onecol(
123122 table="group_users",
124123 keyvalues={"user_id": user_id},
144143 async def get_remote_profile_cache_entries_that_expire(
145144 self, last_checked: int
146145 ) -> List[Dict[str, str]]:
147 """Get all users who haven't been checked since `last_checked`
148 """
146 """Get all users who haven't been checked since `last_checked`"""
149147
150148 def _get_remote_profile_cache_entries_that_expire_txn(txn):
151149 sql = """
167167 )
168168
169169 @cachedList(
170 cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
170 cached_method_name="get_push_rules_for_user",
171 list_name="user_ids",
172 num_args=1,
171173 )
172174 async def bulk_get_push_rules(self, user_ids):
173175 if not user_ids:
194196 use_new_defaults = user_id in self._users_new_default_push_rules
195197
196198 results[user_id] = _load_rules(
197 rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
199 rules,
200 enabled_map_by_user.get(user_id, {}),
201 use_new_defaults,
198202 )
199203
200204 return results
178178 raise NotImplementedError()
179179
180180 @cachedList(
181 cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
181 cached_method_name="get_if_user_has_pusher",
182 list_name="user_ids",
183 num_args=1,
182184 )
183185 async def get_if_users_have_pushers(
184186 self, user_ids: Iterable[str]
262264 params_by_room = {}
263265 for row in res:
264266 params_by_room[row["room_id"]] = ThrottleParams(
265 row["last_sent_ts"], row["throttle_ms"],
267 row["last_sent_ts"],
268 row["throttle_ms"],
266269 )
267270
268271 return params_by_room
159159
160160 Args:
161161 room_id: List of room_ids.
162 to_key: Max stream id to fetch receipts upto.
162 to_key: Max stream id to fetch receipts up to.
163163 from_key: Min stream id to fetch receipts from. None fetches
164164 from the start.
165165
188188
189189 Args:
190190 room_ids: The room id.
191 to_key: Max stream id to fetch receipts upto.
191 to_key: Max stream id to fetch receipts up to.
192192 from_key: Min stream id to fetch receipts from. None fetches
193193 from the start.
194194
207207 async def _get_linearized_receipts_for_room(
208208 self, room_id: str, to_key: int, from_key: Optional[int] = None
209209 ) -> List[dict]:
210 """See get_linearized_receipts_for_room
211 """
210 """See get_linearized_receipts_for_room"""
212211
213212 def f(txn):
214213 if from_key:
303302 }
304303 return results
305304
306 @cached(num_args=2,)
305 @cached(
306 num_args=2,
307 )
307308 async def get_linearized_receipts_for_all_rooms(
308309 self, to_key: int, from_key: Optional[int] = None
309310 ) -> Dict[str, JsonDict]:
311312 to a limit of the latest 100 read receipts.
312313
313314 Args:
314 to_key: Max stream id to fetch receipts upto.
315 to_key: Max stream id to fetch receipts up to.
315316 from_key: Min stream id to fetch receipts from. None fetches
316317 from the start.
317318
7878 # call `find_max_generated_user_id_localpart` each time, which is
7979 # expensive if there are many entries.
8080 self._user_id_seq = build_sequence_generator(
81 database.engine, find_max_generated_user_id_localpart, "user_id_seq",
81 database.engine,
82 find_max_generated_user_id_localpart,
83 "user_id_seq",
8284 )
8385
8486 self._account_validity = hs.config.account_validity
8587 if hs.config.run_background_tasks and self._account_validity.enabled:
8688 self._clock.call_later(
87 0.0, self._set_expiration_date_when_missing,
89 0.0,
90 self._set_expiration_date_when_missing,
8891 )
8992
9093 # Create a background job for culling expired 3PID validity tokens
109112 "creation_ts",
110113 "user_type",
111114 "deactivated",
115 "shadow_banned",
112116 ],
113117 allow_none=True,
114118 desc="get_user_by_id",
368372 """
369373
370374 def set_shadow_banned_txn(txn):
375 user_id = user.to_string()
371376 self.db_pool.simple_update_one_txn(
372377 txn,
373378 table="users",
374 keyvalues={"name": user.to_string()},
379 keyvalues={"name": user_id},
375380 updatevalues={"shadow_banned": shadow_banned},
376381 )
377382 # In order for this to apply immediately, clear the cache for this user.
378383 tokens = self.db_pool.simple_select_onecol_txn(
379384 txn,
380385 table="access_tokens",
381 keyvalues={"user_id": user.to_string()},
386 keyvalues={"user_id": user_id},
382387 retcol="token",
383388 )
384389 for token in tokens:
385390 self._invalidate_cache_and_stream(
386391 txn, self.get_user_by_access_token, (token,)
387392 )
393 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
388394
389395 await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
390396
192192 )
193193
194194 async def get_room_count(self) -> int:
195 """Retrieve the total number of rooms.
196 """
195 """Retrieve the total number of rooms."""
197196
198197 def f(txn):
199198 sql = "SELECT count(*) FROM rooms"
516515 return rooms, room_count[0]
517516
518517 return await self.db_pool.runInteraction(
519 "get_rooms_paginate", _get_rooms_paginate_txn,
518 "get_rooms_paginate",
519 _get_rooms_paginate_txn,
520520 )
521521
522522 @cached(max_entries=10000)
577577 return self.db_pool.cursor_to_dict(txn)
578578
579579 ret = await self.db_pool.runInteraction(
580 "get_retention_policy_for_room", get_retention_policy_for_room_txn,
580 "get_retention_policy_for_room",
581 get_retention_policy_for_room_txn,
581582 )
582583
583584 # If we don't know this room ID, ret will be None, in this case return the default
706707 return local_media_mxcs, remote_media_mxcs
707708
708709 async def quarantine_media_by_id(
709 self, server_name: str, media_id: str, quarantined_by: str,
710 self,
711 server_name: str,
712 media_id: str,
713 quarantined_by: str,
710714 ) -> int:
711715 """quarantines a single local or remote media id
712716
960964 self.config = hs.config
961965
962966 self.db_pool.updates.register_background_update_handler(
963 "insert_room_retention", self._background_insert_retention,
967 "insert_room_retention",
968 self._background_insert_retention,
964969 )
965970
966971 self.db_pool.updates.register_background_update_handler(
10321037 return False
10331038
10341039 end = await self.db_pool.runInteraction(
1035 "insert_room_retention", _background_insert_retention_txn,
1040 "insert_room_retention",
1041 _background_insert_retention_txn,
10361042 )
10371043
10381044 if end:
10431049 async def _background_add_rooms_room_version_column(
10441050 self, progress: dict, batch_size: int
10451051 ):
1046 """Background update to go and add room version inforamtion to `rooms`
1052 """Background update to go and add room version information to `rooms`
10471053 table from `current_state_events` table.
10481054 """
10491055
15871593 LIMIT ?
15881594 OFFSET ?
15891595 """.format(
1590 where_clause=where_clause, order=order,
1596 where_clause=where_clause,
1597 order=order,
15911598 )
15921599
15931600 args += [limit, start]
6969 ):
7070 self._known_servers_count = 1
7171 self.hs.get_clock().looping_call(
72 self._count_known_servers, 60 * 1000,
72 self._count_known_servers,
73 60 * 1000,
7374 )
7475 self.hs.get_clock().call_later(
75 1000, self._count_known_servers,
76 1000,
77 self._count_known_servers,
7678 )
7779 LaterGauge(
7880 "synapse_federation_known_servers",
173175
174176 @cached(max_entries=100000)
175177 async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
176 """ Get the details of a room roughly suitable for use by the room
178 """Get the details of a room roughly suitable for use by the room
177179 summary extension to /sync. Useful when lazy loading room members.
178180 Args:
179181 room_id: The room ID to query
487489 async def get_users_who_share_room_with_user(
488490 self, user_id: str, cache_context: _CacheContext
489491 ) -> Set[str]:
490 """Returns the set of users who share a room with `user_id`
491 """
492 """Returns the set of users who share a room with `user_id`"""
492493 room_ids = await self.get_rooms_for_user(
493494 user_id, on_invalidate=cache_context.invalidate
494495 )
617618 raise NotImplementedError()
618619
619620 @cachedList(
620 cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
621 cached_method_name="_get_joined_profile_from_event_id",
622 list_name="event_ids",
621623 )
622624 async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
623625 """For given set of member event_ids check if they point to a join
801803 async def get_membership_from_event_ids(
802804 self, member_event_ids: Iterable[str]
803805 ) -> List[dict]:
804 """Get user_id and membership of a set of event IDs.
805 """
806 """Get user_id and membership of a set of event IDs."""
806807
807808 return await self.db_pool.simple_select_many_batch(
808809 table="room_memberships",
2222
2323 def run_upgrade(cur, database_engine, *args, **kwargs):
2424 cur.execute(
25 "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
25 "UPDATE remote_media_cache SET last_access_ts = ?",
26 (int(time.time() * 1000),),
2627 )
6666 CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
6767 CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
6868 /* event_search(event_id,room_id,sender,"key",value) */;
69 CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
70 CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
71 CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
72 CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
73 CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
7469 CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) );
7570 CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) );
7671 CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) );
148143 CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') );
149144 CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value )
150145 /* user_directory_search(user_id,value) */;
151 CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value');
152 CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB);
153 CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));
154 CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB);
155 CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB);
156146 CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL );
157147 CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
158148 CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT );
5151
5252 # this inherits from EventsWorkerStore because it calls self.get_events
5353 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
54 """The parts of StateGroupStore that can be called from workers.
55 """
54 """The parts of StateGroupStore that can be called from workers."""
5655
5756 def __init__(self, database: DatabasePool, db_conn, hs):
5857 super().__init__(database, db_conn, hs)
275274 num_args=1,
276275 )
277276 async def _get_state_group_for_events(self, event_ids):
278 """Returns mapping event_id -> state_group
279 """
277 """Returns mapping event_id -> state_group"""
280278 rows = await self.db_pool.simple_select_many_batch(
281279 table="event_to_state_groups",
282280 column="event_id",
337335 columns=["state_group"],
338336 )
339337 self.db_pool.updates.register_background_update_handler(
340 self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
338 self.DELETE_CURRENT_STATE_UPDATE_NAME,
339 self._background_remove_left_rooms,
341340 )
342341
343342 async def _background_remove_left_rooms(self, progress, batch_size):
486485
487486
488487 class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
489 """ Keeps track of the state at a given event.
488 """Keeps track of the state at a given event.
490489
 491490 This is done by the concept of `state groups`. Every event is assigned
492491 a state group (identified by an arbitrary string), which references a
6363 def get_current_state_deltas_txn(txn):
6464 # First we calculate the max stream id that will give us less than
6565 # N results.
66 # We arbitarily limit to 100 stream_id entries to ensure we don't
66 # We arbitrarily limit to 100 stream_id entries to ensure we don't
6767 # select toooo many.
6868 sql = """
6969 SELECT stream_id, count(*)
8080 for stream_id, count in txn:
8181 total += count
8282 if total > 100:
83 # We arbitarily limit to 100 entries to ensure we don't
83 # We arbitrarily limit to 100 entries to ensure we don't
8484 # select toooo many.
8585 logger.debug(
8686 "Clipping current_state_delta_stream rows to stream_id %i",
10001000 ORDER BY {order_by_column} {order}
10011001 LIMIT ? OFFSET ?
10021002 """.format(
1003 sql_base=sql_base, order_by_column=order_by_column, order=order,
1003 sql_base=sql_base,
1004 order_by_column=order_by_column,
1005 order=order,
10041006 )
10051007
10061008 args += [limit, start]
564564 AND e.stream_ordering > ? AND e.stream_ordering <= ?
565565 ORDER BY e.stream_ordering ASC
566566 """
567 txn.execute(sql, (user_id, min_from_id, max_to_id,))
567 txn.execute(
568 sql,
569 (
570 user_id,
571 min_from_id,
572 max_to_id,
573 ),
574 )
568575
569576 rows = [
570577 _EventDictReturn(event_id, None, stream_ordering)
694701 return "t%d-%d" % (topo, token)
695702
696703 def get_stream_id_for_event_txn(
697 self, txn: LoggingTransaction, event_id: str, allow_none=False,
704 self,
705 txn: LoggingTransaction,
706 event_id: str,
707 allow_none=False,
698708 ) -> int:
699709 return self.db_pool.simple_select_one_onecol_txn(
700710 txn=txn,
705715 )
706716
707717 async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
708 """Get the persisted position for an event
709 """
718 """Get the persisted position for an event"""
710719 row = await self.db_pool.simple_select_one(
711720 table="events",
712721 keyvalues={"event_id": event_id},
896905 ) -> Tuple[int, List[EventBase]]:
897906 """Get all new events
898907
899 Returns all events with from_id < stream_ordering <= current_id.
900
901 Args:
902 from_id: the stream_ordering of the last event we processed
903 current_id: the stream_ordering of the most recently processed event
904 limit: the maximum number of events to return
905
906 Returns:
907 A tuple of (next_id, events), where `next_id` is the next value to
908 pass as `from_id` (it will either be the stream_ordering of the
909 last returned event, or, if fewer than `limit` events were found,
910 the `current_id`).
911 """
908 Returns all events with from_id < stream_ordering <= current_id.
909
910 Args:
911 from_id: the stream_ordering of the last event we processed
912 current_id: the stream_ordering of the most recently processed event
913 limit: the maximum number of events to return
914
915 Returns:
916 A tuple of (next_id, events), where `next_id` is the next value to
917 pass as `from_id` (it will either be the stream_ordering of the
918 last returned event, or, if fewer than `limit` events were found,
919 the `current_id`).
920 """
912921
913922 def get_all_new_events_stream_txn(txn):
914923 sql = (
12371246
12381247 @cached()
12391248 async def get_id_for_instance(self, instance_name: str) -> int:
1240 """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
1241 """
1249 """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
12421250
12431251 def _get_id_for_instance_txn(txn):
12441252 instance_id = self.db_pool.simple_select_one_onecol_txn(
6363
6464
6565 class TransactionStore(TransactionWorkerStore):
66 """A collection of queries for handling PDUs.
67 """
66 """A collection of queries for handling PDUs."""
6867
6968 def __init__(self, database: DatabasePool, db_conn, hs):
7069 super().__init__(database, db_conn, hs)
197196 retry_interval: int,
198197 ) -> None:
199198 """Sets the current retry timings for a given destination.
200 Both timings should be zero if retrying is no longer occuring.
199 Both timings should be zero if retrying is no longer occurring.
201200
202201 Args:
203202 destination
298297 )
299298
300299 async def store_destination_rooms_entries(
301 self, destinations: Iterable[str], room_id: str, stream_ordering: int,
300 self,
301 destinations: Iterable[str],
302 room_id: str,
303 stream_ordering: int,
302304 ) -> None:
303305 """
304306 Updates or creates `destination_rooms` entries in batch for a single event.
393395 )
394396
395397 async def get_catch_up_room_event_ids(
396 self, destination: str, last_successful_stream_ordering: int,
398 self,
399 destination: str,
400 last_successful_stream_ordering: int,
397401 ) -> List[str]:
398402 """
399403 Returns at most 50 event IDs and their corresponding stream_orderings
417421
418422 @staticmethod
419423 def _get_catch_up_room_event_ids_txn(
420 txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
424 txn: LoggingTransaction,
425 destination: str,
426 last_successful_stream_ordering: int,
421427 ) -> List[str]:
422428 q = """
423429 SELECT event_id FROM destination_rooms
428434 LIMIT 50
429435 """
430436 txn.execute(
431 q, (destination, last_successful_stream_ordering),
437 q,
438 (destination, last_successful_stream_ordering),
432439 )
433440 event_ids = [row[0] for row in txn]
434441 return event_ids
4343 """
4444
4545 async def create_ui_auth_session(
46 self, clientdict: JsonDict, uri: str, method: str, description: str,
46 self,
47 clientdict: JsonDict,
48 uri: str,
49 method: str,
50 description: str,
4751 ) -> UIAuthSessionData:
4852 """
4953 Creates a new user interactive authentication session.
122126 return UIAuthSessionData(session_id, **result)
123127
124128 async def mark_ui_auth_stage_complete(
125 self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
129 self,
130 session_id: str,
131 stage_type: str,
132 result: Union[str, bool, JsonDict],
126133 ):
127134 """
128135 Mark a session stage as completed.
260267 return serverdict.get(key, default)
261268
262269 async def add_user_agent_ip_to_ui_auth_session(
263 self, session_id: str, user_agent: str, ip: str,
270 self,
271 session_id: str,
272 user_agent: str,
273 ip: str,
264274 ):
265 """Add the given user agent / IP to the tracking table
266 """
275 """Add the given user agent / IP to the tracking table"""
267276 await self.db_pool.simple_upsert(
268277 table="ui_auth_sessions_ips",
269278 keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
272281 )
273282
274283 async def get_user_agents_ips_to_ui_auth_session(
275 self, session_id: str,
284 self,
285 session_id: str,
276286 ) -> List[Tuple[str, str]]:
277287 """Get the given user agents / IPs used during the ui auth process
278288
335335 return len(users_to_work_on)
336336
337337 async def is_room_world_readable_or_publicly_joinable(self, room_id):
338 """Check if the room is either world_readable or publically joinable
339 """
338 """Check if the room is either world_readable or publically joinable"""
340339
341340 # Create a state filter that only queries join and history state event
342341 types_to_filter = (
515514 )
516515
517516 async def delete_all_from_user_dir(self) -> None:
518 """Delete the entire user directory
519 """
517 """Delete the entire user directory"""
520518
521519 def _delete_all_from_user_dir_txn(txn):
522520 txn.execute("DELETE FROM user_directory")
708706
709707 return {row["room_id"] for row in rows}
710708
711 async def get_user_directory_stream_pos(self) -> int:
709 async def get_user_directory_stream_pos(self) -> Optional[int]:
710 """
711 Get the stream ID of the user directory stream.
712
713 Returns:
714 The stream token or None if the initial background update hasn't happened yet.
715 """
712716 return await self.db_pool.simple_select_one_onecol(
713717 table="user_directory_stream_pos",
714718 keyvalues={},
2626
2727
2828 class StateGroupBackgroundUpdateStore(SQLBaseStore):
29 """Defines functions related to state groups needed to run the state backgroud
29 """Defines functions related to state groups needed to run the state background
3030 updates.
3131 """
3232
4747
4848
4949 class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
50 """A data store for fetching/storing state groups.
51 """
50 """A data store for fetching/storing state groups."""
5251
5352 def __init__(self, database: DatabasePool, db_conn, hs):
5453 super().__init__(database, db_conn, hs)
8887 50000,
8988 )
9089 self._state_group_members_cache = DictionaryCache(
91 "*stateGroupMembersCache*", 500000,
90 "*stateGroupMembersCache*",
91 500000,
9292 )
9393
9494 def get_max_state_group_txn(txn: Cursor):
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import platform
1514
1615 from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
1716 from .postgres import PostgresEngine
2726 return Sqlite3Engine(sqlite3, database_config)
2827
2928 if name == "psycopg2":
30 # pypy requires psycopg2cffi rather than psycopg2
31 if platform.python_implementation() == "PyPy":
32 import psycopg2cffi as psycopg2 # type: ignore
33 else:
34 import psycopg2 # type: ignore
29 # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
30 import psycopg2 # type: ignore
3531
3632 return PostgresEngine(psycopg2, database_config)
3733
9393 @property
9494 @abc.abstractmethod
9595 def server_version(self) -> str:
96 """Gets a string giving the server version. For example: '3.22.0'
97 """
96 """Gets a string giving the server version. For example: '3.22.0'"""
9897 ...
9998
10099 @abc.abstractmethod
101100 def in_transaction(self, conn: Connection) -> bool:
102 """Whether the connection is currently in a transaction.
103 """
101 """Whether the connection is currently in a transaction."""
104102 ...
105103
106104 @abc.abstractmethod
137137
138138 @property
139139 def supports_using_any_list(self):
140 """Do we support using `a = ANY(?)` and passing a list
141 """
140 """Do we support using `a = ANY(?)` and passing a list"""
142141 return True
143142
144143 def is_deadlock(self, error):
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 import platform
1415 import struct
1516 import threading
1617 import typing
2728 super().__init__(database_module, database_config)
2829
2930 database = database_config.get("args", {}).get("database")
30 self._is_in_memory = database in (None, ":memory:",)
31 self._is_in_memory = database in (
32 None,
33 ":memory:",
34 )
35
36 if platform.python_implementation() == "PyPy":
37 # pypy's sqlite3 module doesn't handle bytearrays, convert them
38 # back to bytes.
39 database_module.register_adapter(bytearray, lambda array: bytes(array))
3140
3241 # The current max state_group, or None if we haven't looked
3342 # in the DB yet.
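The hunk above moves the PyPy workaround into `Sqlite3Engine`: PyPy's `sqlite3` cannot bind `bytearray` parameters directly, so an adapter converts them to `bytes` first (CPython binds buffer objects natively, which is why the registration is guarded by a PyPy check). A standalone sketch of the same call, under the assumption of a made-up `blobs` table; registering the adapter is harmless on CPython too:

```python
import sqlite3

# Same registration as in the diff: transparently convert bytearray
# parameters to bytes before binding (required on PyPy).
sqlite3.register_adapter(bytearray, lambda array: bytes(array))

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE blobs (data BLOB)")
# The bytearray parameter is adapted to bytes before being bound.
conn.execute("INSERT INTO blobs VALUES (?)", (bytearray(b"\x00\x01"),))
row = conn.execute("SELECT data FROM blobs").fetchone()
```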
5665
5766 @property
5867 def supports_using_any_list(self):
59 """Do we support using `a = ANY(?)` and passing a list
60 """
68 """Do we support using `a = ANY(?)` and passing a list"""
6169 return False
6270
6371 def check_database(self, db_conn, allow_outdated_version: bool = False):
410410 )
411411
412412 for room_id, ev_ctx_rm in events_by_room.items():
413 latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
414 room_id
413 latest_event_ids = (
414 await self.main_store.get_latest_event_ids_in_room(room_id)
415415 )
416416 new_latest_event_ids = await self._calculate_new_extremities(
417417 room_id, ev_ctx_rm, latest_event_ids
888888 continue
889889
890890 logger.debug(
891 "Not dropping as too new and not in new_senders: %s", new_senders,
891 "Not dropping as too new and not in new_senders: %s",
892 new_senders,
892893 )
893894
894895 return new_latest_event_ids
10031004
10041005 remote_event_ids = [
10051006 event_id
1006 for (typ, state_key,), event_id in current_state.items()
1007 for (
1008 typ,
1009 state_key,
1010 ), event_id in current_state.items()
10071011 if typ == EventTypes.Member and not self.is_mine_id(state_key)
10081012 ]
10091013 rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
112112 # which should be empty.
113113 if config is None:
114114 raise ValueError(
115 "config==None in prepare_database, but databse is not empty"
115 "config==None in prepare_database, but database is not empty"
116116 )
117117
118118 # if it's a worker app, refuse to upgrade the database, to avoid multiple
424424 # We don't support using the same file name in the same delta version.
425425 raise PrepareDatabaseException(
426426 "Found multiple delta files with the same name in v%d: %s"
427 % (v, duplicates,)
427 % (
428 v,
429 duplicates,
430 )
428431 )
429432
430433 # We sort to ensure that we apply the delta files in a consistent
531534 names_and_streams: the names and streams of schemas to be applied
532535 """
533536 cur.execute(
534 "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
537 "SELECT file FROM applied_module_schemas WHERE module_name = ?",
538 (modname,),
535539 )
536540 applied_deltas = {d for d, in cur}
537541 for (name, stream) in names_and_streams:
618622
619623 txn.execute("SELECT version, upgraded FROM schema_version")
620624 row = txn.fetchone()
621 current_version = int(row[0]) if row else None
622
623 if current_version:
625
626 if row is not None:
627 current_version = int(row[0])
624628 txn.execute(
625629 "SELECT file FROM applied_schema_deltas WHERE version >= ?",
626630 (current_version,),
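The hunk above replaces a truthiness test on `current_version` with an explicit `row is not None` check. The distinction matters whenever a legitimate value is falsy, such as a schema version of 0; a self-contained illustration (helper names below are invented):

```python
from typing import Optional, Tuple

def version_truthy(row: Optional[Tuple[int]]) -> Optional[int]:
    # Pre-fix pattern: `if current_version:` also skips version 0.
    current_version = int(row[0]) if row else None
    if current_version:
        return current_version
    return None

def version_explicit(row: Optional[Tuple[int]]) -> Optional[int]:
    # Post-fix pattern: only a missing row means "no version".
    if row is not None:
        return int(row[0])
    return None
```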
2525
2626
2727 class PurgeEventsStorage:
28 """High level interface for purging rooms and event history.
29 """
28 """High level interface for purging rooms and event history."""
3029
3130 def __init__(self, hs: "HomeServer", stores: Databases):
3231 self.stores = stores
3332
3433 async def purge_room(self, room_id: str) -> None:
35 """Deletes all record of a room
36 """
34 """Deletes all record of a room"""
3735
3836 state_groups_to_delete = await self.stores.main.purge_room(room_id)
3937 await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
339339
340340
341341 class StateGroupStorage:
342 """High level interface to fetching state for event.
343 """
342 """High level interface to fetching state for event."""
344343
345344 def __init__(self, hs: "HomeServer", stores: "Databases"):
346345 self.stores = stores
399398 async def get_state_groups(
400399 self, room_id: str, event_ids: Iterable[str]
401400 ) -> Dict[int, List[EventBase]]:
402 """ Get the state groups for the given list of event_ids
401 """Get the state groups for the given list of event_ids
403402
404403 Args:
405404 room_id: ID of the room for these events.
1111 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
14 from typing import Any, Iterable, Iterator, List, Optional, Tuple
14 from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
1515
1616 from typing_extensions import Protocol
1717
1919 Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
2020 """
2121
22 _Parameters = Union[Sequence[Any], Mapping[str, Any]]
23
2224
2325 class Cursor(Protocol):
24 def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
26 def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
2527 ...
2628
27 def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
29 def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
30 ...
31
32 def fetchone(self) -> Optional[Tuple]:
33 ...
34
35 def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
2836 ...
2937
3038 def fetchall(self) -> List[Tuple]:
3139 ...
3240
33 def fetchone(self) -> Tuple:
41 @property
42 def description(
43 self,
44 ) -> Optional[
45 Sequence[
46 # Note that this is an approximate typing based on sqlite3 and other
47 # drivers, and may not be entirely accurate.
48 Tuple[
49 str,
50 Optional[Any],
51 Optional[int],
52 Optional[int],
53 Optional[int],
54 Optional[int],
55 Optional[int],
56 ]
57 ]
58 ]:
3459 ...
35
36 @property
37 def description(self) -> Any:
38 return None
3960
4061 @property
4162 def rowcount(self) -> int:
5879 def commit(self) -> None:
5980 ...
6081
61 def rollback(self, *args, **kwargs) -> None:
82 def rollback(self) -> None:
6283 ...
6384
6485 def __enter__(self) -> "Connection":
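The hunk above tightens the PEP-249 typing stubs: `execute` accepts either positional or named parameters, and `fetchone` becomes `Optional[Tuple]` because it returns `None` when the result set is exhausted. A minimal standalone sketch (not Synapse's actual module; names here are illustrative) shows how such a structural `Protocol` is satisfied by a real driver like `sqlite3` without any registration:

```python
# Minimal sketch of a DB-API2 cursor Protocol, checked against sqlite3.
# This is an illustration of the typing approach, not Synapse's real code.
import sqlite3
from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, Union

# PEP-249 allows both positional (sequence) and named (mapping) parameter styles.
_Parameters = Union[Sequence[Any], Mapping[str, Any]]


class Cursor(Protocol):
    def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
        ...

    # fetchone() returns None once the result set is exhausted, hence Optional.
    def fetchone(self) -> Optional[Tuple]:
        ...

    def fetchall(self) -> List[Tuple]:
        ...


def first_row(cur: Cursor, sql: str) -> Optional[Tuple]:
    """Run a query and return its first row, or None for an empty result."""
    cur.execute(sql)
    return cur.fetchone()


conn = sqlite3.connect(":memory:")
# sqlite3.Cursor satisfies the Protocol structurally - no subclassing needed.
row = first_row(conn.cursor(), "SELECT 1 + 1")
```

The `Optional` return type is what forces call sites (as in the `prepare_database` and sequence-generator hunks below) to assert the row is not `None` before indexing into it.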
244244 # and b) noting that if we have seen a run of persisted positions
245245 # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
246246 #
247 # Note: There is no guarentee that the IDs generated by the sequence
247 # Note: There is no guarantee that the IDs generated by the sequence
248248 # will be gapless; gaps can form when e.g. a transaction was rolled
249249 # back. This means that sometimes we won't be able to skip forward the
250250 # position even though everything has been persisted. However, since
276276 self._load_current_ids(db_conn, tables)
277277
278278 def _load_current_ids(
279 self, db_conn, tables: List[Tuple[str, str, str]],
279 self,
280 db_conn,
281 tables: List[Tuple[str, str, str]],
280282 ):
281283 cur = db_conn.cursor(txn_name="_load_current_ids")
282284
363365 rows.sort()
364366
365367 with self._lock:
366 for (instance, stream_id,) in rows:
368 for (
369 instance,
370 stream_id,
371 ) in rows:
367372 stream_id = self._return_factor * stream_id
368373 self._add_persisted_position(stream_id)
369374
417422 # bother, as nothing will read it).
418423 #
419424 # We only do this on the success path so that the persisted current
420 # position points to a persited row with the correct instance name.
425 # position points to a persisted row with the correct instance name.
421426 if self._writers:
422427 txn.call_after(
423428 run_as_background_process,
480485 return self.get_persisted_upto_position()
481486
482487 def get_current_token_for_writer(self, instance_name: str) -> int:
483 """Returns the position of the given writer.
484 """
488 """Returns the position of the given writer."""
485489
486490 # If we don't have an entry for the given instance name, we assume it's a
487491 # new writer.
508512 }
509513
510514 def advance(self, instance_name: str, new_id: int):
511 """Advance the postion of the named writer to the given ID, if greater
515 """Advance the position of the named writer to the given ID, if greater
512516 than existing entry.
513517 """
514518
580584 break
581585
582586 def _update_stream_positions_table_txn(self, txn: Cursor):
583 """Update the `stream_positions` table with newly persisted position.
584 """
587 """Update the `stream_positions` table with newly persisted position."""
585588
586589 if not self._writers:
587590 return
621624
622625 @attr.s(slots=True)
623626 class _MultiWriterCtxManager:
624 """Async context manager returned by MultiWriterIdGenerator
625 """
627 """Async context manager returned by MultiWriterIdGenerator"""
626628
627629 id_gen = attr.ib(type=MultiWriterIdGenerator)
628630 multiple_ids = attr.ib(type=Optional[int], default=None)
105105
106106 def get_next_id_txn(self, txn: Cursor) -> int:
107107 txn.execute("SELECT nextval(?)", (self._sequence_name,))
108 return txn.fetchone()[0]
108 fetch_res = txn.fetchone()
109 assert fetch_res is not None
110 return fetch_res[0]
109111
110112 def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
111113 txn.execute(
121123 stream_name: Optional[str] = None,
122124 positive: bool = True,
123125 ):
124 """See SequenceGenerator.check_consistency for docstring.
125 """
126 """See SequenceGenerator.check_consistency for docstring."""
126127
127128 txn = db_conn.cursor(txn_name="sequence.check_consistency")
128129
146147 txn.execute(
147148 "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
148149 )
149 last_value, is_called = txn.fetchone()
150 fetch_res = txn.fetchone()
151 assert fetch_res is not None
152 last_value, is_called = fetch_res
150153
151154 # If we have an associated stream check the stream_positions table.
152155 max_in_stream_positions = None
468468 )
469469
470470 def __attrs_post_init__(self):
471 """Validates that both `topological` and `instance_map` aren't set.
472 """
471 """Validates that both `topological` and `instance_map` aren't set."""
473472
474473 if self.instance_map and self.topological:
475474 raise ValueError(
497496 instance_name = await store.get_name_from_instance_id(instance_id)
498497 instance_map[instance_name] = pos
499498
500 return cls(topological=None, stream=stream, instance_map=instance_map,)
499 return cls(
500 topological=None,
501 stream=stream,
502 instance_map=instance_map,
503 )
501504 except Exception:
502505 pass
503506 raise SynapseError(400, "Invalid token %r" % (string,))
674677 persisted in the same room after this position will be after the
675678 returned `RoomStreamToken`.
676679
677 Note: no guarentees are made about ordering w.r.t. events in other
680 Note: no guarantees are made about ordering w.r.t. events in other
678681 rooms.
679682 """
680683 # Doing the naive thing satisfies the desired properties described in
251251 self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry]
252252
253253 def is_queued(self, key: Hashable) -> bool:
254 """Checks whether there is a process queued up waiting
255 """
254 """Checks whether there is a process queued up waiting"""
256255 entry = self.key_to_defer.get(key)
257256 if not entry:
258257 # No entry so nothing is waiting.
451450
452451
453452 def timeout_deferred(
454 deferred: defer.Deferred, timeout: float, reactor: IReactorTime,
453 deferred: defer.Deferred,
454 timeout: float,
455 reactor: IReactorTime,
455456 ) -> defer.Deferred:
456457 """The in built twisted `Deferred.addTimeout` fails to time out deferreds
457458 that have a canceller that throws exceptions. This method creates a new
496497 delayed_call = reactor.callLater(timeout, time_it_out)
497498
498499 def convert_cancelled(value: failure.Failure):
499 # if the orgininal deferred was cancelled, and our timeout has fired, then
500 # if the original deferred was cancelled, and our timeout has fired, then
500501 # the reason it was cancelled was due to our timeout. Turn the CancelledError
501502 # into a TimeoutError.
502503 if timed_out[0] and value.check(CancelledError):
528529
529530 @attr.s(slots=True, frozen=True)
530531 class DoneAwaitable:
531 """Simple awaitable that returns the provided value.
532 """
532 """Simple awaitable that returns the provided value."""
533533
534534 value = attr.ib()
535535
544544
545545
546546 def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
547 """Convert a value to an awaitable if not already an awaitable.
548 """
547 """Convert a value to an awaitable if not already an awaitable."""
549548 if inspect.isawaitable(value):
550549 assert isinstance(value, Awaitable)
551550 return value
148148
149149
150150 def intern_string(string):
151 """Takes a (potentially) unicode string and interns it if it's ascii
152 """
151 """Takes a (potentially) unicode string and interns it if it's ascii"""
153152 if string is None:
154153 return None
155154
160159
161160
162161 def intern_dict(dictionary):
163 """Takes a dictionary and interns well known keys and their values
164 """
162 """Takes a dictionary and interns well known keys and their values"""
165163 return {
166164 KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
167165 for key, value in dictionary.items()
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
16
17 from twisted.internet.defer import Deferred
18 from twisted.python.failure import Failure
19
20 from synapse.logging.context import make_deferred_yieldable, run_in_background
21
22 TV = TypeVar("TV")
23
24
25 class CachedCall(Generic[TV]):
26 """A wrapper for asynchronous calls whose results should be shared
27
28 This is useful for wrapping asynchronous functions, where there might be multiple
29 callers, but we only want to call the underlying function once (and have the result
30 returned to all callers).
31
32 Similar results can be achieved via a lock of some form, but that typically requires
33 more boilerplate (and ends up being less efficient).
34
35 Correctly handles Synapse logcontexts (logs and resource usage for the underlying
36 function are logged against the logcontext which is active when get() is first
37 called).
38
39 Example usage:
40
41 _cached_val = CachedCall(_load_prop)
42
43 async def handle_request() -> X:
44 # We can call this multiple times, but it will result in a single call to
45 # _load_prop().
46 return await _cached_val.get()
47
48 async def _load_prop() -> X:
49 await difficult_operation()
50
51
52 The implementation is deliberately single-shot (ie, once the call is initiated,
53 there is no way to ask for it to be re-run). This keeps the implementation and
54 semantics simple. If you want to make a new call, simply replace the whole
55 CachedCall object.
56 """
57
58 __slots__ = ["_callable", "_deferred", "_result"]
59
60 def __init__(self, f: Callable[[], Awaitable[TV]]):
61 """
62 Args:
63 f: The underlying function. Only one call to this function will be alive
64 at once (per instance of CachedCall)
65 """
66 self._callable = f # type: Optional[Callable[[], Awaitable[TV]]]
67 self._deferred = None # type: Optional[Deferred]
68 self._result = None # type: Union[None, Failure, TV]
69
70 async def get(self) -> TV:
71 """Kick off the call if necessary, and return the result"""
72
73 # Fire off the callable now if this is our first time
74 if not self._deferred:
75 self._deferred = run_in_background(self._callable)
76
77 # we will never need the callable again, so make sure it can be GCed
78 self._callable = None
79
80 # once the deferred completes, store the result. We cannot simply leave the
81 # result in the deferred, since if it's a Failure, GCing the deferred
82 # would then log a critical error about unhandled Failures.
83 def got_result(r):
84 self._result = r
85
86 self._deferred.addBoth(got_result)
87
88 # TODO: consider cancellation semantics. Currently, if the call to get()
89 # is cancelled, the underlying call will continue (and any future calls
90 # will get the result/exception), which I think is *probably* ok, modulo
91 # the fact the underlying call may be logged to a cancelled logcontext,
92 # and any eventual exception may not be reported.
93
94 # we can now await the deferred, and once it completes, return the result.
95 await make_deferred_yieldable(self._deferred)
96
97 # I *think* this is the easiest way to correctly raise a Failure without having
98 # to gut-wrench into the implementation of Deferred.
99 d = Deferred()
100 d.callback(self._result)
101 return await d
102
103
104 class RetryOnExceptionCachedCall(Generic[TV]):
105 """A wrapper around CachedCall which will retry the call if an exception is thrown
106
107 This is used in much the same way as CachedCall, but adds some extra functionality
108 so that if the underlying function throws an exception, then the next call to get()
109 will initiate another call to the underlying function. (Any calls to get() which
110 are already pending will raise the exception.)
111 """
112
113 __slots__ = ["_cachedcall"]
114
115 def __init__(self, f: Callable[[], Awaitable[TV]]):
116 async def _wrapper() -> TV:
117 try:
118 return await f()
119 except Exception:
120 # the call raised an exception: replace the underlying CachedCall to
121 # trigger another call next time get() is called
122 self._cachedcall = CachedCall(_wrapper)
123 raise
124
125 self._cachedcall = CachedCall(_wrapper)
126
127 async def get(self) -> TV:
128 return await self._cachedcall.get()
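The essence of `CachedCall` — many callers, one underlying invocation, everyone sees the same result — can be sketched with plain `asyncio` instead of Twisted Deferreds and logcontexts. This is a hedged analogue for illustration only (all names are mine, and it omits the Failure/GC handling the real class does):

```python
# asyncio-based sketch of the "single-shot shared call" idea behind CachedCall.
# Illustrative only; Synapse's real class is built on Twisted Deferreds.
import asyncio
from typing import Awaitable, Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class SharedCall(Generic[T]):
    def __init__(self, f: Callable[[], Awaitable[T]]):
        self._f = f
        self._task: Optional[asyncio.Task] = None

    async def get(self) -> T:
        # The first caller kicks off the underlying coroutine; later callers
        # await the same Task, so the function body runs exactly once.
        if self._task is None:
            self._task = asyncio.ensure_future(self._f())
        return await self._task


calls = 0


async def expensive() -> int:
    global calls
    calls += 1
    await asyncio.sleep(0)
    return 42


async def main() -> list:
    shared = SharedCall(expensive)
    # Three concurrent callers, one underlying call.
    return await asyncio.gather(shared.get(), shared.get(), shared.get())


results = asyncio.run(main())
```

As with `CachedCall`, the wrapper is deliberately single-shot: to force a fresh call you replace the whole object, which is exactly the trick `RetryOnExceptionCachedCall` uses on failure.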
121121
122122
123123 def lru_cache(
124 max_entries: int = 1000, cache_context: bool = False,
124 max_entries: int = 1000,
125 cache_context: bool = False,
125126 ) -> Callable[[F], _LruCachedFunction[F]]:
126127 """A method decorator that applies a memoizing cache around the function.
127128
155156
156157 def func(orig: F) -> _LruCachedFunction[F]:
157158 desc = LruCacheDescriptor(
158 orig, max_entries=max_entries, cache_context=cache_context,
159 orig,
160 max_entries=max_entries,
161 cache_context=cache_context,
159162 )
160163 return cast(_LruCachedFunction[F], desc)
161164
169172 sentinel = object()
170173
171174 def __init__(
172 self, orig, max_entries: int = 1000, cache_context: bool = False,
175 self,
176 orig,
177 max_entries: int = 1000,
178 cache_context: bool = False,
173179 ):
174180 super().__init__(orig, num_args=None, cache_context=cache_context)
175181 self.max_entries = max_entries
176182
177183 def __get__(self, obj, owner):
178184 cache = LruCache(
179 cache_name=self.orig.__name__, max_size=self.max_entries,
185 cache_name=self.orig.__name__,
186 max_size=self.max_entries,
180187 ) # type: LruCache[CacheKey, Any]
181188
182189 get_cache_key = self.cache_key_builder
211218
212219
213220 class DeferredCacheDescriptor(_CacheDescriptorBase):
214 """ A method decorator that applies a memoizing cache around the function.
221 """A method decorator that applies a memoizing cache around the function.
215222
216223 This caches deferreds, rather than the results themselves. Deferreds that
217224 fail are removed from the cache.
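The `lru_cache` decorator being reformatted here memoizes a method behind an LRU cache. A rough standard-library analogue conveys the shape (this is not Synapse's implementation, which additionally caches Deferreds and supports cache contexts and invalidation):

```python
# Rough stdlib analogue of a memoizing method decorator. Synapse's
# lru_cache/cached descriptors also handle Deferreds, invalidation and
# cache contexts, none of which functools.lru_cache provides.
import functools

call_count = 0


class Squarer:
    @functools.lru_cache(maxsize=1000)
    def square(self, n: int) -> int:
        global call_count
        call_count += 1  # count actual (non-cached) executions
        return n * n


s = Squarer()
# The repeated square(7) call is served from the cache.
results = [s.square(7), s.square(7), s.square(8)]
```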
8383 return False
8484
8585 def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
86 """Returns True if the entity may have been updated since stream_pos
87 """
86 """Returns True if the entity may have been updated since stream_pos"""
8887 assert isinstance(stream_pos, int)
8988
9089 if stream_pos < self._earliest_known_stream_pos:
132131 return result
133132
134133 def has_any_entity_changed(self, stream_pos: int) -> bool:
135 """Returns if any entity has changed
136 """
134 """Returns if any entity has changed"""
137135 assert type(stream_pos) is int
138136
139137 if not self._cache:
107107 return await maybe_awaitable(observer(*args, **kwargs))
108108 except Exception as e:
109109 logger.warning(
110 "%s signal observer %s failed: %r", self.name, observer, e,
110 "%s signal observer %s failed: %r",
111 self.name,
112 observer,
113 e,
111114 )
112115
113116 deferreds = [run_in_background(do, o) for o in self.observers]
8282 self._producer.resumeProducing()
8383
8484 def unregisterProducer(self):
85 """Part of IProducer interface
86 """
85 """Part of IProducer interface"""
8786 self._producer = None
8887 if not self._finished_deferred.called:
8988 self._bytes_queue.put_nowait(None)
9089
9190 def write(self, bytes):
92 """Part of IProducer interface
93 """
91 """Part of IProducer interface"""
9492 if self._write_exception:
9593 raise self._write_exception
9694
106104 self._producer.pauseProducing()
107105
108106 def _writer(self):
109 """This is run in a background thread to write to the file.
110 """
107 """This is run in a background thread to write to the file."""
111108 try:
112109 while self._producer or not self._bytes_queue.empty():
113110 # If we've paused the producer check if we should resume the
134131 self._file_obj.close()
135132
136133 def wait(self):
137 """Returns a deferred that resolves when finished writing to file
138 """
134 """Returns a deferred that resolves when finished writing to file"""
139135 return make_deferred_yieldable(self._finished_deferred)
140136
141137 def _resume_paused_producer(self):
142 """Gets called if we should resume producing after being paused
143 """
138 """Gets called if we should resume producing after being paused"""
144139 if self._paused_producer and self._producer:
145140 self._paused_producer = False
146141 self._producer.resumeProducing()
6161
6262
6363 def sorted_topologically(
64 nodes: Iterable[T], graph: Mapping[T, Collection[T]],
64 nodes: Iterable[T],
65 graph: Mapping[T, Collection[T]],
6566 ) -> Generator[T, None, None]:
6667 """Given a set of nodes and a graph, yield the nodes in toplogical order.
6768
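For context, `sorted_topologically` yields each node only after the nodes it depends on. A simplified sketch using Kahn's algorithm (not Synapse's exact implementation; the heap just makes output order deterministic among ties):

```python
# Simplified sketch of a sorted_topologically-style helper using Kahn's
# algorithm. `graph` maps a node to the nodes it depends on. Illustrative
# only - not Synapse's exact implementation.
import heapq
from typing import Collection, Dict, Generator, Iterable, List, Mapping, TypeVar

T = TypeVar("T")


def sorted_topologically(
    nodes: Iterable[T],
    graph: Mapping[T, Collection[T]],
) -> Generator[T, None, None]:
    degree: Dict[T, int] = {node: 0 for node in nodes}
    dependents: Dict[T, List[T]] = {node: [] for node in degree}

    for node, deps in graph.items():
        for dep in deps:
            if dep in degree and node in degree:
                degree[node] += 1
                dependents[dep].append(node)

    # Nodes with no outstanding dependencies are ready to emit.
    heap = [node for node, d in degree.items() if d == 0]
    heapq.heapify(heap)
    while heap:
        node = heapq.heappop(heap)
        yield node
        for child in dependents[node]:
            degree[child] -= 1
            if degree[child] == 0:
                heapq.heappush(heap, child)


# 2 depends on 1, and 3 depends on 2: ancestors are emitted first.
order = list(sorted_topologically([3, 1, 2], {2: [1], 3: [2]}))
```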
1414
1515
1616 class JsonEncodedObject:
17 """ A common base class for defining protocol units that are represented
17 """A common base class for defining protocol units that are represented
1818 as JSON.
1919
2020 Attributes:
3838 """
3939
4040 def __init__(self, **kwargs):
41 """ Takes the dict of `kwargs` and loads all keys that are *valid*
41 """Takes the dict of `kwargs` and loads all keys that are *valid*
4242 (i.e., are included in the `valid_keys` list) into the dictionary`
4343 instance variable.
4444
6060 self.unrecognized_keys[k] = v
6161
6262 def get_dict(self):
63 """ Converts this protocol unit into a :py:class:`dict`, ready to be
63 """Converts this protocol unit into a :py:class:`dict`, ready to be
6464 encoded as JSON.
6565
6666 The keys it encodes are: `valid_keys` - `internal_keys`
160160 return self._logging_context.get_resource_usage()
161161
162162 def _update_in_flight(self, metrics):
163 """Gets called when processing in flight metrics
164 """
163 """Gets called when processing in flight metrics"""
165164 duration = self.clock.time() - self.start
166165
167166 metrics.real_time_max = max(metrics.real_time_max, duration)
2424
2525
2626 def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
27 """ Loads a synapse module with its config
27 """Loads a synapse module with its config
2828
2929 Args:
3030 provider: a dict with keys 'module' (the module name) and 'config'
203203 # We don't raise here as its perfectly valid for contexts to
204204 # change in a function, as long as it sets the correct context
205205 # on resolving (which is checked separately).
206 err = (
207 "%s changed context from %s to %s, happened between lines %d and %d in %s"
208 % (
209 frame.f_code.co_name,
210 expected_context,
211 current_context(),
212 last_yield_line_no,
213 frame.f_lineno,
214 frame.f_code.co_filename,
215 )
206 err = "%s changed context from %s to %s, happened between lines %d and %d in %s" % (
207 frame.f_code.co_name,
208 expected_context,
209 current_context(),
210 last_yield_line_no,
211 frame.f_lineno,
212 frame.f_code.co_filename,
216213 )
217214 changes.append(err)
218215
2424 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
2525
2626 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
27 client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
27 CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
2828
2929 # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
3030 # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
4141 rand = random.SystemRandom()
4242
4343
44 def random_string(length):
44 def random_string(length: int) -> str:
4545 return "".join(rand.choice(string.ascii_letters) for _ in range(length))
4646
4747
48 def random_string_with_symbols(length):
48 def random_string_with_symbols(length: int) -> str:
4949 return "".join(rand.choice(_string_with_symbols) for _ in range(length))
5050
5151
52 def is_ascii(s):
53 if isinstance(s, bytes):
54 try:
55 s.decode("ascii").encode("ascii")
56 except UnicodeDecodeError:
57 return False
58 except UnicodeEncodeError:
59 return False
60 return True
52 def is_ascii(s: bytes) -> bool:
53 try:
54 s.decode("ascii").encode("ascii")
55 except UnicodeDecodeError:
56 return False
57 except UnicodeEncodeError:
58 return False
59 return True
6160
6261
63 def assert_valid_client_secret(client_secret):
64 """Validate that a given string matches the client_secret regex defined by the spec"""
65 if client_secret_regex.match(client_secret) is None:
62 def assert_valid_client_secret(client_secret: str) -> None:
63 """Validate that a given string matches the client_secret defined by the spec"""
64 if (
65 len(client_secret) <= 0
66 or len(client_secret) > 255
67 or CLIENT_SECRET_REGEX.match(client_secret) is None
68 ):
6669 raise SynapseError(
6770 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
6871 )
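The tightened check above adds an explicit 1–255 length bound alongside the spec's character-class regex. The same validation can be exercised standalone (here `ValueError` stands in for Synapse's `SynapseError`):

```python
# Standalone version of the tightened client_secret check: the spec's
# character class plus an explicit 1-255 length bound. ValueError stands
# in for Synapse's SynapseError.
import re

CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")


def assert_valid_client_secret(client_secret: str) -> None:
    if (
        len(client_secret) <= 0
        or len(client_secret) > 255
        or CLIENT_SECRET_REGEX.match(client_secret) is None
    ):
        raise ValueError("Invalid client_secret parameter")


def is_valid(secret: str) -> bool:
    try:
        assert_valid_client_secret(secret)
    except ValueError:
        return False
    return True


ok = is_valid("this.is_a.valid-secret")
too_long = is_valid("a" * 256)   # rejected: exceeds 255 characters
bad_char = is_valid("no spaces") # rejected: space not in the allowed class
empty = is_valid("")             # rejected: zero length
```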
7979 events = [e for e in events if not e.internal_metadata.is_soft_failed()]
8080
8181 types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
82
8283 event_id_to_state = await storage.state.get_state_for_events(
8384 frozenset(e.event_id for e in events),
8485 state_filter=StateFilter.from_types(types),
232233
233234 elif visibility == HistoryVisibility.SHARED and is_peeking:
234235 # if the visibility is shared, users cannot see the event unless
235 # they have *subequently* joined the room (or were members at the
236 # they have *subsequently* joined the room (or were members at the
236237 # time, of course)
237238 #
238239 # XXX: if the user has subsequently joined and then left again,
9595 runner.args.loops = orig_loops
9696 loops = "auto"
9797 runner.bench_time_func(
98 suite.__name__ + "_" + str(loops), make_test(suite.main),
98 suite.__name__ + "_" + str(loops),
99 make_test(suite.main),
99100 )
9797
9898 logger = logging.getLogger("synapse.logging.test_terse_json")
9999 _setup_stdlib_logging(
100 hs_config, log_config, logBeginner=beginner,
100 hs_config,
101 log_config,
102 logBeginner=beginner,
101103 )
102104
103105 # Wait for it to connect...
1616
1717 import pymacaroons
1818
19 from twisted.internet import defer
20
2119 from synapse.api.auth import Auth
2220 from synapse.api.constants import UserTypes
2321 from synapse.api.errors import (
3230 from synapse.types import UserID
3331
3432 from tests import unittest
35 from tests.utils import mock_getRawHeaders, setup_test_homeserver
36
37
38 class AuthTestCase(unittest.TestCase):
39 @defer.inlineCallbacks
40 def setUp(self):
41 self.state_handler = Mock()
33 from tests.test_utils import simple_async_mock
34 from tests.utils import mock_getRawHeaders
35
36
37 class AuthTestCase(unittest.HomeserverTestCase):
38 def prepare(self, reactor, clock, hs):
4239 self.store = Mock()
4340
44 self.hs = yield setup_test_homeserver(self.addCleanup)
45 self.hs.get_datastore = Mock(return_value=self.store)
46 self.hs.get_auth_handler().store = self.store
47 self.auth = Auth(self.hs)
41 hs.get_datastore = Mock(return_value=self.store)
42 hs.get_auth_handler().store = self.store
43 self.auth = Auth(hs)
4844
4945 # AuthBlocking reads from the hs' config on initialization. We need to
5046 # modify its config instead of the hs'
5652 # this is overridden for the appservice tests
5753 self.store.get_app_service_by_token = Mock(return_value=None)
5854
59 self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
60 self.store.is_support_user = Mock(return_value=defer.succeed(False))
61
62 @defer.inlineCallbacks
55 self.store.insert_client_ip = simple_async_mock(None)
56 self.store.is_support_user = simple_async_mock(False)
57
6358 def test_get_user_by_req_user_valid_token(self):
6459 user_info = TokenLookupResult(
6560 user_id=self.test_user, token_id=5, device_id="device"
6661 )
67 self.store.get_user_by_access_token = Mock(
68 return_value=defer.succeed(user_info)
69 )
70
71 request = Mock(args={})
72 request.args[b"access_token"] = [self.test_token]
73 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
74 requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
62 self.store.get_user_by_access_token = simple_async_mock(user_info)
63
64 request = Mock(args={})
65 request.args[b"access_token"] = [self.test_token]
66 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
67 requester = self.get_success(self.auth.get_user_by_req(request))
7568 self.assertEquals(requester.user.to_string(), self.test_user)
7669
7770 def test_get_user_by_req_user_bad_token(self):
78 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
79
80 request = Mock(args={})
81 request.args[b"access_token"] = [self.test_token]
82 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
83 d = defer.ensureDeferred(self.auth.get_user_by_req(request))
84 f = self.failureResultOf(d, InvalidClientTokenError).value
71 self.store.get_user_by_access_token = simple_async_mock(None)
72
73 request = Mock(args={})
74 request.args[b"access_token"] = [self.test_token]
75 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
76 f = self.get_failure(
77 self.auth.get_user_by_req(request), InvalidClientTokenError
78 ).value
8579 self.assertEqual(f.code, 401)
8680 self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
8781
8882 def test_get_user_by_req_user_missing_token(self):
8983 user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
90 self.store.get_user_by_access_token = Mock(
91 return_value=defer.succeed(user_info)
92 )
93
94 request = Mock(args={})
95 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
96 d = defer.ensureDeferred(self.auth.get_user_by_req(request))
97 f = self.failureResultOf(d, MissingClientTokenError).value
84 self.store.get_user_by_access_token = simple_async_mock(user_info)
85
86 request = Mock(args={})
87 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
88 f = self.get_failure(
89 self.auth.get_user_by_req(request), MissingClientTokenError
90 ).value
9891 self.assertEqual(f.code, 401)
9992 self.assertEqual(f.errcode, "M_MISSING_TOKEN")
10093
101 @defer.inlineCallbacks
10294 def test_get_user_by_req_appservice_valid_token(self):
10395 app_service = Mock(
10496 token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
10597 )
10698 self.store.get_app_service_by_token = Mock(return_value=app_service)
107 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
99 self.store.get_user_by_access_token = simple_async_mock(None)
108100
109101 request = Mock(args={})
110102 request.getClientIP.return_value = "127.0.0.1"
111103 request.args[b"access_token"] = [self.test_token]
112104 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
113 requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
105 requester = self.get_success(self.auth.get_user_by_req(request))
114106 self.assertEquals(requester.user.to_string(), self.test_user)
115107
116 @defer.inlineCallbacks
117108 def test_get_user_by_req_appservice_valid_token_good_ip(self):
118109 from netaddr import IPSet
119110
124115 ip_range_whitelist=IPSet(["192.168/16"]),
125116 )
126117 self.store.get_app_service_by_token = Mock(return_value=app_service)
127 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
118 self.store.get_user_by_access_token = simple_async_mock(None)
128119
129120 request = Mock(args={})
130121 request.getClientIP.return_value = "192.168.10.10"
131122 request.args[b"access_token"] = [self.test_token]
132123 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
133 requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
124 requester = self.get_success(self.auth.get_user_by_req(request))
134125 self.assertEquals(requester.user.to_string(), self.test_user)
135126
136127 def test_get_user_by_req_appservice_valid_token_bad_ip(self):
143134 ip_range_whitelist=IPSet(["192.168/16"]),
144135 )
145136 self.store.get_app_service_by_token = Mock(return_value=app_service)
146 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
137 self.store.get_user_by_access_token = simple_async_mock(None)
147138
148139 request = Mock(args={})
149140 request.getClientIP.return_value = "131.111.8.42"
150141 request.args[b"access_token"] = [self.test_token]
151142 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
152 d = defer.ensureDeferred(self.auth.get_user_by_req(request))
153 f = self.failureResultOf(d, InvalidClientTokenError).value
143 f = self.get_failure(
144 self.auth.get_user_by_req(request), InvalidClientTokenError
145 ).value
154146 self.assertEqual(f.code, 401)
155147 self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
156148
157149 def test_get_user_by_req_appservice_bad_token(self):
158150 self.store.get_app_service_by_token = Mock(return_value=None)
159 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
160
161 request = Mock(args={})
162 request.args[b"access_token"] = [self.test_token]
163 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
164 d = defer.ensureDeferred(self.auth.get_user_by_req(request))
165 f = self.failureResultOf(d, InvalidClientTokenError).value
151 self.store.get_user_by_access_token = simple_async_mock(None)
152
153 request = Mock(args={})
154 request.args[b"access_token"] = [self.test_token]
155 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
156 f = self.get_failure(
157 self.auth.get_user_by_req(request), InvalidClientTokenError
158 ).value
166159 self.assertEqual(f.code, 401)
167160 self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
168161
169162 def test_get_user_by_req_appservice_missing_token(self):
170163 app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
171164 self.store.get_app_service_by_token = Mock(return_value=app_service)
172 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
173
174 request = Mock(args={})
175 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
176 d = defer.ensureDeferred(self.auth.get_user_by_req(request))
177 f = self.failureResultOf(d, MissingClientTokenError).value
165 self.store.get_user_by_access_token = simple_async_mock(None)
166
167 request = Mock(args={})
168 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
169 f = self.get_failure(
170 self.auth.get_user_by_req(request), MissingClientTokenError
171 ).value
178172 self.assertEqual(f.code, 401)
179173 self.assertEqual(f.errcode, "M_MISSING_TOKEN")
180174
181 @defer.inlineCallbacks
182175 def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
183176 masquerading_user_id = b"@doppelganger:matrix.org"
184177 app_service = Mock(
187180 app_service.is_interested_in_user = Mock(return_value=True)
188181 self.store.get_app_service_by_token = Mock(return_value=app_service)
189182 # This just needs to return a truth-y value.
190 self.store.get_user_by_id = Mock(
191 return_value=defer.succeed({"is_guest": False})
192 )
193 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
183 self.store.get_user_by_id = simple_async_mock({"is_guest": False})
184 self.store.get_user_by_access_token = simple_async_mock(None)
194185
195186 request = Mock(args={})
196187 request.getClientIP.return_value = "127.0.0.1"
197188 request.args[b"access_token"] = [self.test_token]
198189 request.args[b"user_id"] = [masquerading_user_id]
199190 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
200 requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
191 requester = self.get_success(self.auth.get_user_by_req(request))
201192 self.assertEquals(
202193 requester.user.to_string(), masquerading_user_id.decode("utf8")
203194 )
209200 )
210201 app_service.is_interested_in_user = Mock(return_value=False)
211202 self.store.get_app_service_by_token = Mock(return_value=app_service)
212 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
203 self.store.get_user_by_access_token = simple_async_mock(None)
213204
214205 request = Mock(args={})
215206 request.getClientIP.return_value = "127.0.0.1"
216207 request.args[b"access_token"] = [self.test_token]
217208 request.args[b"user_id"] = [masquerading_user_id]
218209 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
219 d = defer.ensureDeferred(self.auth.get_user_by_req(request))
220 self.failureResultOf(d, AuthError)
221
222 @defer.inlineCallbacks
210 self.get_failure(self.auth.get_user_by_req(request), AuthError)
211
223212 def test_get_user_from_macaroon(self):
224 self.store.get_user_by_access_token = Mock(
225 return_value=defer.succeed(
226 TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
227 )
213 self.store.get_user_by_access_token = simple_async_mock(
214 TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
228215 )
229216
230217 user_id = "@baldrick:matrix.org"
236223 macaroon.add_first_party_caveat("gen = 1")
237224 macaroon.add_first_party_caveat("type = access")
238225 macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
239 user_info = yield defer.ensureDeferred(
226 user_info = self.get_success(
240227 self.auth.get_user_by_access_token(macaroon.serialize())
241228 )
242229 self.assertEqual(user_id, user_info.user_id)
245232 # from the db.
246233 self.assertEqual(user_info.device_id, "device")
247234
248 @defer.inlineCallbacks
249235 def test_get_guest_user_from_macaroon(self):
250 self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
251 self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
236 self.store.get_user_by_id = simple_async_mock({"is_guest": True})
237 self.store.get_user_by_access_token = simple_async_mock(None)
252238
253239 user_id = "@baldrick:matrix.org"
254240 macaroon = pymacaroons.Macaroon(
262248 macaroon.add_first_party_caveat("guest = true")
263249 serialized = macaroon.serialize()
264250
265 user_info = yield defer.ensureDeferred(
266 self.auth.get_user_by_access_token(serialized)
267 )
251 user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
268252 self.assertEqual(user_id, user_info.user_id)
269253 self.assertTrue(user_info.is_guest)
270254 self.store.get_user_by_id.assert_called_with(user_id)
271255
272 @defer.inlineCallbacks
273256 def test_cannot_use_regular_token_as_guest(self):
274257 USER_ID = "@percy:matrix.org"
275 self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
276 self.store.get_device = Mock(return_value=defer.succeed(None))
277
278 token = yield defer.ensureDeferred(
258 self.store.add_access_token_to_user = simple_async_mock(None)
259 self.store.get_device = simple_async_mock(None)
260
261 token = self.get_success(
279262 self.hs.get_auth_handler().get_access_token_for_user_id(
280263 USER_ID, "DEVICE", valid_until_ms=None
281264 )
288271 puppets_user_id=None,
289272 )
290273
291 def get_user(tok):
274 async def get_user(tok):
292275 if token != tok:
293 return defer.succeed(None)
294 return defer.succeed(
295 TokenLookupResult(
296 user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
297 )
276 return None
277 return TokenLookupResult(
278 user_id=USER_ID,
279 is_guest=False,
280 token_id=1234,
281 device_id="DEVICE",
298282 )
299283
300284 self.store.get_user_by_access_token = get_user
301 self.store.get_user_by_id = Mock(
302 return_value=defer.succeed({"is_guest": False})
303 )
285 self.store.get_user_by_id = simple_async_mock({"is_guest": False})
304286
305287 # check the token works
306288 request = Mock(args={})
307289 request.args[b"access_token"] = [token.encode("ascii")]
308290 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
309 requester = yield defer.ensureDeferred(
291 requester = self.get_success(
310292 self.auth.get_user_by_req(request, allow_guest=True)
311293 )
312294 self.assertEqual(UserID.from_string(USER_ID), requester.user)
322304 request.args[b"access_token"] = [guest_tok.encode("ascii")]
323305 request.requestHeaders.getRawHeaders = mock_getRawHeaders()
324306
325 with self.assertRaises(InvalidClientCredentialsError) as cm:
326 yield defer.ensureDeferred(
327 self.auth.get_user_by_req(request, allow_guest=True)
328 )
329
330 self.assertEqual(401, cm.exception.code)
331 self.assertEqual("Guest access token used for regular user", cm.exception.msg)
307 cm = self.get_failure(
308 self.auth.get_user_by_req(request, allow_guest=True),
309 InvalidClientCredentialsError,
310 )
311
312 self.assertEqual(401, cm.value.code)
313 self.assertEqual("Guest access token used for regular user", cm.value.msg)
332314
333315 self.store.get_user_by_id.assert_called_with(USER_ID)
334316
335 @defer.inlineCallbacks
336317 def test_blocking_mau(self):
337318 self.auth_blocking._limit_usage_by_mau = False
338319 self.auth_blocking._max_mau_value = 50
340321 small_number_of_users = 1
341322
342323 # Ensure no error thrown
343 yield defer.ensureDeferred(self.auth.check_auth_blocking())
324 self.get_success(self.auth.check_auth_blocking())
344325
345326 self.auth_blocking._limit_usage_by_mau = True
346327
347 self.store.get_monthly_active_count = Mock(
348 return_value=defer.succeed(lots_of_users)
349 )
350
351 with self.assertRaises(ResourceLimitError) as e:
352 yield defer.ensureDeferred(self.auth.check_auth_blocking())
353 self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
354 self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
355 self.assertEquals(e.exception.code, 403)
328 self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
329
330 e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
331 self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
332 self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
333 self.assertEquals(e.value.code, 403)
356334
357335 # Ensure does not throw an error
358 self.store.get_monthly_active_count = Mock(
359 return_value=defer.succeed(small_number_of_users)
360 )
361 yield defer.ensureDeferred(self.auth.check_auth_blocking())
362
363 @defer.inlineCallbacks
336 self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
337 self.get_success(self.auth.check_auth_blocking())
338
364339 def test_blocking_mau__depending_on_user_type(self):
365340 self.auth_blocking._max_mau_value = 50
366341 self.auth_blocking._limit_usage_by_mau = True
367342
368 self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
343 self.store.get_monthly_active_count = simple_async_mock(100)
369344 # Support users allowed
370 yield defer.ensureDeferred(
371 self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
372 )
373 self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
345 self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
346 self.store.get_monthly_active_count = simple_async_mock(100)
374347 # Bots not allowed
375 with self.assertRaises(ResourceLimitError):
376 yield defer.ensureDeferred(
377 self.auth.check_auth_blocking(user_type=UserTypes.BOT)
378 )
379 self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
348 self.get_failure(
349 self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
350 )
351 self.store.get_monthly_active_count = simple_async_mock(100)
380352 # Real users not allowed
381 with self.assertRaises(ResourceLimitError):
382 yield defer.ensureDeferred(self.auth.check_auth_blocking())
383
384 @defer.inlineCallbacks
353 self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
354
385355 def test_reserved_threepid(self):
386356 self.auth_blocking._limit_usage_by_mau = True
387357 self.auth_blocking._max_mau_value = 1
388 self.store.get_monthly_active_count = lambda: defer.succeed(2)
358 self.store.get_monthly_active_count = simple_async_mock(2)
389359 threepid = {"medium": "email", "address": "reserved@server.com"}
390360 unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
391361 self.auth_blocking._mau_limits_reserved_threepids = [threepid]
392362
393 with self.assertRaises(ResourceLimitError):
394 yield defer.ensureDeferred(self.auth.check_auth_blocking())
395
396 with self.assertRaises(ResourceLimitError):
397 yield defer.ensureDeferred(
398 self.auth.check_auth_blocking(threepid=unknown_threepid)
399 )
400
401 yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
402
403 @defer.inlineCallbacks
363 self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
364
365 self.get_failure(
366 self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
367 )
368
369 self.get_success(self.auth.check_auth_blocking(threepid=threepid))
370
404371 def test_hs_disabled(self):
405372 self.auth_blocking._hs_disabled = True
406373 self.auth_blocking._hs_disabled_message = "Reason for being disabled"
407 with self.assertRaises(ResourceLimitError) as e:
408 yield defer.ensureDeferred(self.auth.check_auth_blocking())
409 self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
410 self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
411 self.assertEquals(e.exception.code, 403)
412
413 @defer.inlineCallbacks
374 e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
375 self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
376 self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
377 self.assertEquals(e.value.code, 403)
378
414379 def test_hs_disabled_no_server_notices_user(self):
415380 """Check that 'hs_disabled_message' works correctly when there is no
416381 server_notices user.
421386
422387 self.auth_blocking._hs_disabled = True
423388 self.auth_blocking._hs_disabled_message = "Reason for being disabled"
424 with self.assertRaises(ResourceLimitError) as e:
425 yield defer.ensureDeferred(self.auth.check_auth_blocking())
426 self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
427 self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
428 self.assertEquals(e.exception.code, 403)
429
430 @defer.inlineCallbacks
389 e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
390 self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
391 self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
392 self.assertEquals(e.value.code, 403)
393
431394 def test_server_notices_mxid_special_cased(self):
432395 self.auth_blocking._hs_disabled = True
433396 user = "@user:server"
434397 self.auth_blocking._server_notices_mxid = user
435398 self.auth_blocking._hs_disabled_message = "Reason for being disabled"
436 yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
399 self.get_success(self.auth.check_auth_blocking(user))
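The hunks above repeatedly replace `Mock(return_value=defer.succeed(x))` with `simple_async_mock(x)` when stubbing store methods that became `async`. A minimal sketch of what such a helper can look like (the real `tests.test_utils.simple_async_mock` may differ in detail) is:

```python
from unittest.mock import Mock


def simple_async_mock(return_value=None):
    """A Mock whose every call returns a fresh coroutine resolving to
    `return_value` -- a stand-in for the older Deferred-based
    `Mock(return_value=defer.succeed(return_value))` pattern."""

    async def _cb(*args, **kwargs):
        return return_value

    # side_effect makes each call produce a new coroutine, so the mock
    # can safely be awaited more than once, and the usual call-assertion
    # methods (assert_called_with, etc.) still work.
    return Mock(side_effect=_cb)
```

Because each call yields a new coroutine, the stub behaves correctly even when the code under test awaits it several times, which a single pre-created coroutine would not.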
1717
1818 import jsonschema
1919
20 from twisted.internet import defer
21
2220 from synapse.api.constants import EventContentFields
2321 from synapse.api.errors import SynapseError
2422 from synapse.api.filtering import Filter
2523 from synapse.events import make_event_from_dict
2624
2725 from tests import unittest
28 from tests.utils import setup_test_homeserver
2926
3027 user_localpart = "test_user"
3128
3835 return make_event_from_dict(kwargs)
3936
4037
41 class FilteringTestCase(unittest.TestCase):
42 def setUp(self):
43 hs = setup_test_homeserver(self.addCleanup)
38 class FilteringTestCase(unittest.HomeserverTestCase):
39 def prepare(self, reactor, clock, hs):
4440 self.filtering = hs.get_filtering()
4541 self.datastore = hs.get_datastore()
4642
350346
351347 self.assertTrue(Filter(definition).check(event))
352348
353 @defer.inlineCallbacks
354349 def test_filter_presence_match(self):
355350 user_filter_json = {"presence": {"types": ["m.*"]}}
356 filter_id = yield defer.ensureDeferred(
351 filter_id = self.get_success(
357352 self.datastore.add_user_filter(
358353 user_localpart=user_localpart, user_filter=user_filter_json
359354 )
361356 event = MockEvent(sender="@foo:bar", type="m.profile")
362357 events = [event]
363358
364 user_filter = yield defer.ensureDeferred(
359 user_filter = self.get_success(
365360 self.filtering.get_user_filter(
366361 user_localpart=user_localpart, filter_id=filter_id
367362 )
370365 results = user_filter.filter_presence(events=events)
371366 self.assertEquals(events, results)
372367
373 @defer.inlineCallbacks
374368 def test_filter_presence_no_match(self):
375369 user_filter_json = {"presence": {"types": ["m.*"]}}
376370
377 filter_id = yield defer.ensureDeferred(
371 filter_id = self.get_success(
378372 self.datastore.add_user_filter(
379373 user_localpart=user_localpart + "2", user_filter=user_filter_json
380374 )
386380 )
387381 events = [event]
388382
389 user_filter = yield defer.ensureDeferred(
383 user_filter = self.get_success(
390384 self.filtering.get_user_filter(
391385 user_localpart=user_localpart + "2", filter_id=filter_id
392386 )
395389 results = user_filter.filter_presence(events=events)
396390 self.assertEquals([], results)
397391
398 @defer.inlineCallbacks
399392 def test_filter_room_state_match(self):
400393 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
401 filter_id = yield defer.ensureDeferred(
394 filter_id = self.get_success(
402395 self.datastore.add_user_filter(
403396 user_localpart=user_localpart, user_filter=user_filter_json
404397 )
406399 event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
407400 events = [event]
408401
409 user_filter = yield defer.ensureDeferred(
402 user_filter = self.get_success(
410403 self.filtering.get_user_filter(
411404 user_localpart=user_localpart, filter_id=filter_id
412405 )
415408 results = user_filter.filter_room_state(events=events)
416409 self.assertEquals(events, results)
417410
418 @defer.inlineCallbacks
419411 def test_filter_room_state_no_match(self):
420412 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
421 filter_id = yield defer.ensureDeferred(
413 filter_id = self.get_success(
422414 self.datastore.add_user_filter(
423415 user_localpart=user_localpart, user_filter=user_filter_json
424416 )
428420 )
429421 events = [event]
430422
431 user_filter = yield defer.ensureDeferred(
423 user_filter = self.get_success(
432424 self.filtering.get_user_filter(
433425 user_localpart=user_localpart, filter_id=filter_id
434426 )
453445
454446 self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
455447
456 @defer.inlineCallbacks
457448 def test_add_filter(self):
458449 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
459450
460 filter_id = yield defer.ensureDeferred(
451 filter_id = self.get_success(
461452 self.filtering.add_user_filter(
462453 user_localpart=user_localpart, user_filter=user_filter_json
463454 )
467458 self.assertEquals(
468459 user_filter_json,
469460 (
470 yield defer.ensureDeferred(
461 self.get_success(
471462 self.datastore.get_user_filter(
472463 user_localpart=user_localpart, filter_id=0
473464 )
475466 ),
476467 )
477468
478 @defer.inlineCallbacks
479469 def test_get_filter(self):
480470 user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
481471
482 filter_id = yield defer.ensureDeferred(
472 filter_id = self.get_success(
483473 self.datastore.add_user_filter(
484474 user_localpart=user_localpart, user_filter=user_filter_json
485475 )
486476 )
487477
488 filter = yield defer.ensureDeferred(
478 filter = self.get_success(
489479 self.filtering.get_user_filter(
490480 user_localpart=user_localpart, filter_id=filter_id
491481 )
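The filtering-test hunks above drop `@defer.inlineCallbacks`/`yield defer.ensureDeferred(...)` in favour of `HomeserverTestCase.get_success` and `get_failure`. Conceptually these drive an awaitable to completion on the test reactor and return its result (or a Twisted `Failure` wrapping the expected exception). An asyncio-based sketch of the idea (assumed names; Synapse's real helpers pump the Twisted test reactor instead) is:

```python
import asyncio


def get_success(awaitable):
    # Drive the awaitable to completion synchronously and return its result.
    return asyncio.run(awaitable)


class _Failure:
    """Mirrors Twisted's Failure: the caught exception sits on `.value`."""

    def __init__(self, value):
        self.value = value


def get_failure(awaitable, exc_type):
    # Run the awaitable, expecting it to raise `exc_type`; return the
    # captured failure so the test can assert on .value attributes.
    try:
        asyncio.run(awaitable)
    except exc_type as e:
        return _Failure(e)
    raise AssertionError("did not raise %r" % (exc_type,))
```

This is why the rewritten tests can assert on `f.value.code` and `f.value.errcode` directly, where the old `failureResultOf`/`assertRaises` style needed a Deferred or a context manager.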
4242
4343 def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
4444 appservice = ApplicationService(
45 None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
45 None,
46 "example.com",
47 id="foo",
48 rate_limited=True,
49 sender="@as:example.com",
4650 )
4751 as_requester = create_requester("@user:example.com", app_service=appservice)
4852
6771
6872 def test_allowed_appservice_via_can_requester_do_action(self):
6973 appservice = ApplicationService(
70 None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
74 None,
75 "example.com",
76 id="foo",
77 rate_limited=False,
78 sender="@as:example.com",
7179 )
7280 as_requester = create_requester("@user:example.com", app_service=appservice)
7381
112120 limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
113121
114122 # First attempt should be allowed
115 allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
123 allowed, time_allowed = limiter.can_do_action(
124 ("test_id",),
125 _time_now_s=0,
126 )
116127 self.assertTrue(allowed)
117128 self.assertEqual(10.0, time_allowed)
118129
119130 # Second attempt, 1s later, will fail
120 allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
131 allowed, time_allowed = limiter.can_do_action(
132 ("test_id",),
133 _time_now_s=1,
134 )
121135 self.assertFalse(allowed)
122136 self.assertEqual(10.0, time_allowed)
123137
126126 self.assertEqual(cache.max_size, 150)
127127
128128 def test_cache_with_asterisk_in_name(self):
129 """Some caches have asterisks in their name, test that they are set correctly.
130 """
129 """Some caches have asterisks in their name, test that they are set correctly."""
131130
132131 config = {
133132 "caches": {
163162 t.read_config(config, config_dir_path="", data_dir_path="")
164163
165164 cache = LruCache(
166 max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
165 max_size=t.caches.event_cache_size,
166 apply_cache_factor_from_config=False,
167167 )
168168 add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
169169
1414
1515 import yaml
1616
17 from synapse.config.server import ServerConfig, is_threepid_reserved
17 from synapse.config._base import ConfigError
18 from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved
1819
1920 from tests import unittest
2021
127128 )
128129
129130 self.assertEqual(conf["listeners"], expected_listeners)
131
132
133 class GenerateIpSetTestCase(unittest.TestCase):
134 def test_empty(self):
135 ip_set = generate_ip_set(())
136 self.assertFalse(ip_set)
137
138 ip_set = generate_ip_set((), ())
139 self.assertFalse(ip_set)
140
141 def test_generate(self):
142 """Check adding IPv4 and IPv6 addresses."""
143 # IPv4 address
144 ip_set = generate_ip_set(("1.2.3.4",))
145 self.assertEqual(len(ip_set.iter_cidrs()), 4)
146
147 # IPv4 CIDR
148 ip_set = generate_ip_set(("1.2.3.4/24",))
149 self.assertEqual(len(ip_set.iter_cidrs()), 4)
150
151 # IPv6 address
152 ip_set = generate_ip_set(("2001:db8::8a2e:370:7334",))
153 self.assertEqual(len(ip_set.iter_cidrs()), 1)
154
155 # IPv6 CIDR
156 ip_set = generate_ip_set(("2001:db8::/104",))
157 self.assertEqual(len(ip_set.iter_cidrs()), 1)
158
159 # The addresses can overlap OK.
160 ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
161 self.assertEqual(len(ip_set.iter_cidrs()), 4)
162
163 def test_extra(self):
164 """Extra IP addresses are treated the same."""
165 ip_set = generate_ip_set((), ("1.2.3.4",))
166 self.assertEqual(len(ip_set.iter_cidrs()), 4)
167
168 ip_set = generate_ip_set(("1.1.1.1",), ("1.2.3.4",))
169 self.assertEqual(len(ip_set.iter_cidrs()), 8)
170
171 # They can duplicate without error.
172 ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
173 self.assertEqual(len(ip_set.iter_cidrs()), 4)
174
175 def test_bad_value(self):
176 """An error should be raised if a bad value is passed in."""
177 with self.assertRaises(ConfigError):
178 generate_ip_set(("not-an-ip",))
179
180 with self.assertRaises(ConfigError):
181 generate_ip_set(("1.2.3.4/128",))
182
183 with self.assertRaises(ConfigError):
184 generate_ip_set((":::",))
185
186 # The following get treated as empty data.
187 self.assertFalse(generate_ip_set(None))
188 self.assertFalse(generate_ip_set({}))
399399 )
400400
401401 def build_perspectives_response(
402 self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
402 self,
403 server_name: str,
404 signing_key: SigningKey,
405 valid_until_ts: int,
403406 ) -> dict:
404407 """
405408 Build a valid perspectives server response to a request for the given key
454457 VALID_UNTIL_TS = 200 * 1000
455458
456459 response = self.build_perspectives_response(
457 SERVER_NAME, testkey, VALID_UNTIL_TS,
460 SERVER_NAME,
461 testkey,
462 VALID_UNTIL_TS,
458463 )
459464
460465 self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
4242
4343 event, context = self.get_success(
4444 create_event(
45 self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
45 self.hs,
46 room_id=self.room_id,
47 type="m.test",
48 sender=self.user_id,
4649 )
4750 )
4851
149149 )
150150
151151 # Artificially raise the complexity
152 self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
153 600
152 self.hs.get_datastore().get_current_state_event_counts = (
153 lambda x: make_awaitable(600)
154154 )
155155
156156 d = handler._remote_join(
278278
279279 ret = self.get_success(
280280 e2e_handler.upload_signatures_for_device_keys(
281 u1, {u1: {"D1": d1_json, "D2": d2_json}},
281 u1,
282 {u1: {"D1": d1_json, "D2": d2_json}},
282283 )
283284 )
284285 self.assertEqual(ret["failures"], {})
485486 self.assertGreaterEqual(content["stream_id"], prev_stream_id)
486487 return content["stream_id"]
487488
488 def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
489 """Check that the txn has an EDU with a signing key update.
490 """
489 def check_signing_key_update_txn(
490 self,
491 txn: JsonDict,
492 ) -> None:
493 """Check that the txn has an EDU with a signing key update."""
491494 edus = txn["edus"]
492495 self.assertEqual(len(edus), 1)
493496
501504
502505 self.get_success(
503506 self.hs.get_e2e_keys_handler().upload_keys_for_user(
504 user_id, device_id, {"device_keys": device_dict},
507 user_id,
508 device_id,
509 {"device_keys": device_dict},
505510 )
506511 )
507512 return sk
4343 self.token2 = self.login("user2", "password")
4444
4545 def test_single_public_joined_room(self):
46 """Test that we write *all* events for a public room
47 """
46 """Test that we write *all* events for a public room"""
4847 room_id = self.helper.create_room_as(
4948 self.user1, tok=self.token1, is_public=True
5049 )
115114 self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
116115
117116 def test_single_left_room(self):
118 """Tests that we don't see events in the room after we leave.
119 """
117 """Tests that we don't see events in the room after we leave."""
120118 room_id = self.helper.create_room_as(self.user1, tok=self.token1)
121119 self.helper.send(room_id, body="Hello!", tok=self.token1)
122120 self.helper.join(room_id, self.user2, tok=self.token2)
189187 self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
190188
191189 def test_invite(self):
192 """Tests that pending invites get handled correctly.
193 """
190 """Tests that pending invites get handled correctly."""
194191 room_id = self.helper.create_room_as(self.user1, tok=self.token1)
195192 self.helper.send(room_id, body="Hello!", tok=self.token1)
196193 self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)
3434 self.mock_scheduler = Mock()
3535 hs = Mock()
3636 hs.get_datastore.return_value = self.mock_store
37 self.mock_store.get_received_ts.return_value = defer.succeed(0)
38 self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
37 self.mock_store.get_received_ts.return_value = make_awaitable(0)
38 self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
3939 hs.get_application_service_api.return_value = self.mock_as_api
4040 hs.get_application_service_scheduler.return_value = self.mock_scheduler
4141 hs.get_clock.return_value = MockClock()
4949 self._mkservice(is_interested=False),
5050 ]
5151
52 self.mock_as_api.query_user.return_value = defer.succeed(True)
52 self.mock_as_api.query_user.return_value = make_awaitable(True)
5353 self.mock_store.get_app_services.return_value = services
54 self.mock_store.get_user_by_id.return_value = defer.succeed([])
54 self.mock_store.get_user_by_id.return_value = make_awaitable([])
5555
5656 event = Mock(
5757 sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
5858 )
5959 self.mock_store.get_new_events_for_appservice.side_effect = [
60 defer.succeed((0, [event])),
61 defer.succeed((0, [])),
60 make_awaitable((0, [event])),
61 make_awaitable((0, [])),
6262 ]
6363 self.handler.notify_interested_services(RoomStreamToken(None, 0))
6464
7171 services = [self._mkservice(is_interested=True)]
7272 services[0].is_interested_in_user.return_value = True
7373 self.mock_store.get_app_services.return_value = services
74 self.mock_store.get_user_by_id.return_value = defer.succeed(None)
74 self.mock_store.get_user_by_id.return_value = make_awaitable(None)
7575
7676 event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
77 self.mock_as_api.query_user.return_value = defer.succeed(True)
77 self.mock_as_api.query_user.return_value = make_awaitable(True)
7878 self.mock_store.get_new_events_for_appservice.side_effect = [
79 defer.succeed((0, [event])),
80 defer.succeed((0, [])),
79 make_awaitable((0, [event])),
80 make_awaitable((0, [])),
8181 ]
8282
8383 self.handler.notify_interested_services(RoomStreamToken(None, 0))
8989 services = [self._mkservice(is_interested=True)]
9090 services[0].is_interested_in_user.return_value = True
9191 self.mock_store.get_app_services.return_value = services
92 self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
92 self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
9393
9494 event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
95 self.mock_as_api.query_user.return_value = defer.succeed(True)
95 self.mock_as_api.query_user.return_value = make_awaitable(True)
9696 self.mock_store.get_new_events_for_appservice.side_effect = [
97 defer.succeed((0, [event])),
98 defer.succeed((0, [])),
97 make_awaitable((0, [event])),
98 make_awaitable((0, [])),
9999 ]
100100
101101 self.handler.notify_interested_services(RoomStreamToken(None, 0))
105105 "query_user called when it shouldn't have been.",
106106 )
107107
108 @defer.inlineCallbacks
109108 def test_query_room_alias_exists(self):
110109 room_alias_str = "#foo:bar"
111110 room_alias = Mock()
126125 Mock(room_id=room_id, servers=servers)
127126 )
128127
129 result = yield defer.ensureDeferred(
130 self.handler.query_room_alias_exists(room_alias)
128 result = self.successResultOf(
129 defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
131130 )
132131
133132 self.mock_as_api.query_alias.assert_called_once_with(
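The appservice-handler hunks above swap `defer.succeed(...)` return values for `make_awaitable(...)`. The point of such a helper is to produce something re-awaitable, since a plain coroutine assigned as a mock's fixed `return_value` could only be awaited once. A sketch of the idea (the real `tests.test_utils.make_awaitable` may be implemented differently, e.g. with a pre-resolved Future):

```python
class make_awaitable:
    """An awaitable that resolves immediately to `result` and, unlike a
    coroutine, can be awaited any number of times -- safe to use as a
    Mock's fixed return_value."""

    def __init__(self, result):
        self._result = result

    def __await__(self):
        # Each await builds a fresh generator that finishes immediately,
        # delivering `result` via StopIteration.
        if False:
            yield  # pragma: no cover - makes this a generator function
        return self._result
```

Assigning `mock.return_value = make_awaitable(x)` then works even if the handler awaits the same stubbed call result repeatedly.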
1515
1616 import pymacaroons
1717
18 from twisted.internet import defer
19
20 import synapse
21 import synapse.api.errors
22 from synapse.api.errors import ResourceLimitError
18 from synapse.api.errors import AuthError, ResourceLimitError
2319
2420 from tests import unittest
2521 from tests.test_utils import make_awaitable
26 from tests.utils import setup_test_homeserver
27
28
29 class AuthTestCase(unittest.TestCase):
30 @defer.inlineCallbacks
31 def setUp(self):
32 self.hs = yield setup_test_homeserver(self.addCleanup)
33 self.auth_handler = self.hs.get_auth_handler()
34 self.macaroon_generator = self.hs.get_macaroon_generator()
22
23
24 class AuthTestCase(unittest.HomeserverTestCase):
25 def prepare(self, reactor, clock, hs):
26 self.auth_handler = hs.get_auth_handler()
27 self.macaroon_generator = hs.get_macaroon_generator()
3528
3629 # MAU tests
3730 # AuthBlocking reads from the hs' config on initialization. We need to
3831 # modify its config instead of the hs'
39 self.auth_blocking = self.hs.get_auth()._auth_blocking
32 self.auth_blocking = hs.get_auth()._auth_blocking
4033 self.auth_blocking._max_mau_value = 50
4134
4235 self.small_number_of_users = 1
5144 self.fail("some_user was not in %s" % macaroon.inspect())
5245
5346 def test_macaroon_caveats(self):
54 self.hs.get_clock().now = 5000
55
5647 token = self.macaroon_generator.generate_access_token("a_user")
5748 macaroon = pymacaroons.Macaroon.deserialize(token)
5849
7566 v.satisfy_general(verify_nonce)
7667 v.verify(macaroon, self.hs.config.macaroon_secret_key)
7768
78 @defer.inlineCallbacks
7969 def test_short_term_login_token_gives_user_id(self):
80 self.hs.get_clock().now = 1000
81
8270 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
83 user_id = yield defer.ensureDeferred(
71 user_id = self.get_success(
8472 self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
8573 )
8674 self.assertEqual("a_user", user_id)
8775
8876 # when we advance the clock, the token should be rejected
89 self.hs.get_clock().now = 6000
90 with self.assertRaises(synapse.api.errors.AuthError):
91 yield defer.ensureDeferred(
92 self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
93 )
94
95 @defer.inlineCallbacks
77 self.reactor.advance(6)
78 self.get_failure(
79 self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
80 AuthError,
81 )
82
9683 def test_short_term_login_token_cannot_replace_user_id(self):
9784 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
9885 macaroon = pymacaroons.Macaroon.deserialize(token)
9986
100 user_id = yield defer.ensureDeferred(
87 user_id = self.get_success(
10188 self.auth_handler.validate_short_term_login_token_and_get_user_id(
10289 macaroon.serialize()
10390 )
10895 # user_id.
10996 macaroon.add_first_party_caveat("user_id = b_user")
11097
111 with self.assertRaises(synapse.api.errors.AuthError):
112 yield defer.ensureDeferred(
113 self.auth_handler.validate_short_term_login_token_and_get_user_id(
114 macaroon.serialize()
115 )
116 )
117
118 @defer.inlineCallbacks
98 self.get_failure(
99 self.auth_handler.validate_short_term_login_token_and_get_user_id(
100 macaroon.serialize()
101 ),
102 AuthError,
103 )
104
119105 def test_mau_limits_disabled(self):
120106 self.auth_blocking._limit_usage_by_mau = False
121107 # Ensure does not throw exception
122 yield defer.ensureDeferred(
123 self.auth_handler.get_access_token_for_user_id(
124 "user_a", device_id=None, valid_until_ms=None
125 )
126 )
127
128 yield defer.ensureDeferred(
129 self.auth_handler.validate_short_term_login_token_and_get_user_id(
130 self._get_macaroon().serialize()
131 )
132 )
133
134 @defer.inlineCallbacks
108 self.get_success(
109 self.auth_handler.get_access_token_for_user_id(
110 "user_a", device_id=None, valid_until_ms=None
111 )
112 )
113
114 self.get_success(
115 self.auth_handler.validate_short_term_login_token_and_get_user_id(
116 self._get_macaroon().serialize()
117 )
118 )
119
135120 def test_mau_limits_exceeded_large(self):
136121 self.auth_blocking._limit_usage_by_mau = True
137122 self.hs.get_datastore().get_monthly_active_count = Mock(
138123 return_value=make_awaitable(self.large_number_of_users)
139124 )
140125
141 with self.assertRaises(ResourceLimitError):
142 yield defer.ensureDeferred(
143 self.auth_handler.get_access_token_for_user_id(
144 "user_a", device_id=None, valid_until_ms=None
145 )
146 )
126 self.get_failure(
127 self.auth_handler.get_access_token_for_user_id(
128 "user_a", device_id=None, valid_until_ms=None
129 ),
130 ResourceLimitError,
131 )
147132
148133 self.hs.get_datastore().get_monthly_active_count = Mock(
149134 return_value=make_awaitable(self.large_number_of_users)
150135 )
151 with self.assertRaises(ResourceLimitError):
152 yield defer.ensureDeferred(
153 self.auth_handler.validate_short_term_login_token_and_get_user_id(
154 self._get_macaroon().serialize()
155 )
156 )
157
158 @defer.inlineCallbacks
136 self.get_failure(
137 self.auth_handler.validate_short_term_login_token_and_get_user_id(
138 self._get_macaroon().serialize()
139 ),
140 ResourceLimitError,
141 )
142
159143 def test_mau_limits_parity(self):
144 # Ensure we're not at the unix epoch.
145 self.reactor.advance(1)
160146 self.auth_blocking._limit_usage_by_mau = True
161147
148 # Set the server to be at the edge of too many users.
149 self.hs.get_datastore().get_monthly_active_count = Mock(
150 return_value=make_awaitable(self.auth_blocking._max_mau_value)
151 )
152
162153 # If not in monthly active cohort
163 self.hs.get_datastore().get_monthly_active_count = Mock(
164 return_value=make_awaitable(self.auth_blocking._max_mau_value)
165 )
166 with self.assertRaises(ResourceLimitError):
167 yield defer.ensureDeferred(
168 self.auth_handler.get_access_token_for_user_id(
169 "user_a", device_id=None, valid_until_ms=None
170 )
171 )
172
173 self.hs.get_datastore().get_monthly_active_count = Mock(
174 return_value=make_awaitable(self.auth_blocking._max_mau_value)
175 )
176 with self.assertRaises(ResourceLimitError):
177 yield defer.ensureDeferred(
178 self.auth_handler.validate_short_term_login_token_and_get_user_id(
179 self._get_macaroon().serialize()
180 )
181 )
154 self.get_failure(
155 self.auth_handler.get_access_token_for_user_id(
156 "user_a", device_id=None, valid_until_ms=None
157 ),
158 ResourceLimitError,
159 )
160 self.get_failure(
161 self.auth_handler.validate_short_term_login_token_and_get_user_id(
162 self._get_macaroon().serialize()
163 ),
164 ResourceLimitError,
165 )
166
182167 # If in monthly active cohort
183168 self.hs.get_datastore().user_last_seen_monthly_active = Mock(
184 return_value=make_awaitable(self.hs.get_clock().time_msec())
185 )
186 self.hs.get_datastore().get_monthly_active_count = Mock(
187 return_value=make_awaitable(self.auth_blocking._max_mau_value)
188 )
189 yield defer.ensureDeferred(
190 self.auth_handler.get_access_token_for_user_id(
191 "user_a", device_id=None, valid_until_ms=None
192 )
193 )
194 self.hs.get_datastore().user_last_seen_monthly_active = Mock(
195 return_value=make_awaitable(self.hs.get_clock().time_msec())
196 )
197 self.hs.get_datastore().get_monthly_active_count = Mock(
198 return_value=make_awaitable(self.auth_blocking._max_mau_value)
199 )
200 yield defer.ensureDeferred(
201 self.auth_handler.validate_short_term_login_token_and_get_user_id(
202 self._get_macaroon().serialize()
203 )
204 )
205
206 @defer.inlineCallbacks
169 return_value=make_awaitable(self.clock.time_msec())
170 )
171 self.get_success(
172 self.auth_handler.get_access_token_for_user_id(
173 "user_a", device_id=None, valid_until_ms=None
174 )
175 )
176 self.get_success(
177 self.auth_handler.validate_short_term_login_token_and_get_user_id(
178 self._get_macaroon().serialize()
179 )
180 )
181
207182 def test_mau_limits_not_exceeded(self):
208183 self.auth_blocking._limit_usage_by_mau = True
209184
211186 return_value=make_awaitable(self.small_number_of_users)
212187 )
213188 # Ensure does not raise exception
214 yield defer.ensureDeferred(
189 self.get_success(
215190 self.auth_handler.get_access_token_for_user_id(
216191 "user_a", device_id=None, valid_until_ms=None
217192 )
220195 self.hs.get_datastore().get_monthly_active_count = Mock(
221196 return_value=make_awaitable(self.small_number_of_users)
222197 )
223 yield defer.ensureDeferred(
198 self.get_success(
224199 self.auth_handler.validate_short_term_login_token_and_get_user_id(
225200 self._get_macaroon().serialize()
226201 )
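The recurring change in the hunks above converts old-style `@defer.inlineCallbacks` tests (with `yield defer.ensureDeferred(...)` and `assertRaises`) to `HomeserverTestCase`'s synchronous `get_success`/`get_failure` helpers. A minimal sketch of those helpers' semantics, using plain asyncio coroutines instead of Twisted Deferreds and a fake reactor (an assumption purely for illustration; Synapse's real helpers drive a simulated Twisted reactor):

```python
import asyncio


async def _drive(awaitable):
    # Wrap so asyncio.run() accepts any awaitable, not just coroutines.
    return await awaitable


def get_success(awaitable):
    """Run the awaitable to completion and return its result."""
    return asyncio.run(_drive(awaitable))


def get_failure(awaitable, exc_type):
    """Run the awaitable, assert it fails with exc_type, return the exception."""
    try:
        asyncio.run(_drive(awaitable))
    except exc_type as exc:
        return exc
    raise AssertionError("expected %s to be raised" % exc_type.__name__)
```

Usage then mirrors the diff: `user_id = get_success(handler.validate_token(token))` and `get_failure(handler.validate_token(token), AuthError)`. One difference to note: Synapse's real `get_failure` returns a Twisted `Failure` wrapper around the exception, which is why the converted tests read the status code as `e.value.code` rather than `e.code`.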
1515 from synapse.handlers.cas_handler import CasResponse
1616
1717 from tests.test_utils import simple_async_mock
18 from tests.unittest import HomeserverTestCase
18 from tests.unittest import HomeserverTestCase, override_config
1919
2020 # These are a few constants that are used as config parameters in the tests.
2121 BASE_URL = "https://synapse/"
3131 "server_url": SERVER_URL,
3232 "service_url": BASE_URL,
3333 }
34
35 # Update this config with what's in the default config so that
36 # override_config works as expected.
37 cas_config.update(config.get("cas_config", {}))
3438 config["cas_config"] = cas_config
3539
3640 return config
114118 "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
115119 )
116120
121 @override_config(
122 {
123 "cas_config": {
124 "required_attributes": {"userGroup": "staff", "department": None}
125 }
126 }
127 )
128 def test_required_attributes(self):
129 """The required attributes must be met from the CAS response."""
130
131 # stub out the auth handler
132 auth_handler = self.hs.get_auth_handler()
133 auth_handler.complete_sso_login = simple_async_mock()
134
135 # The response doesn't have the proper userGroup or department.
136 cas_response = CasResponse("test_user", {})
137 request = _mock_request()
138 self.get_success(
139 self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
140 )
141 auth_handler.complete_sso_login.assert_not_called()
142
143 # The response doesn't have any department.
144 cas_response = CasResponse("test_user", {"userGroup": "staff"})
145 request.reset_mock()
146 self.get_success(
147 self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
148 )
149 auth_handler.complete_sso_login.assert_not_called()
150
151 # Add the proper attributes and it should succeed.
152 cas_response = CasResponse(
153 "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
154 )
155 request.reset_mock()
156 self.get_success(
157 self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
158 )
159
160 # check that the auth handler got called as expected
161 auth_handler.complete_sso_login.assert_called_once_with(
162 "@test_user:test", request, "redirect_uri", None, new_user=True
163 )
164
117165
118166 def _mock_request():
119167 """Returns a mock which will stand in as a SynapseRequest"""
120 return Mock(spec=["getClientIP", "getHeader"])
168 return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
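The new `test_required_attributes` test above exercises an all-or-nothing attribute check: a required attribute with value `None` merely has to be present in the CAS response, while a non-`None` value must appear among that attribute's values. A hypothetical standalone sketch of that rule (not Synapse's actual implementation):

```python
def meets_required_attributes(attributes, required):
    """Return True iff every required CAS attribute is satisfied.

    `required` maps attribute name -> expected value; None means
    "present with any value". CAS responses can carry several values
    per attribute, so expected values are matched against the list.
    """
    for name, expected in required.items():
        if name not in attributes:
            return False
        if expected is not None and expected not in attributes[name]:
            return False
    return True
```

This reproduces the three cases in the test: an empty response fails, a response with only `userGroup` fails (no `department`), and a response with `userGroup` containing `"staff"` plus any `department` succeeds.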
259259 # Create a new login for the user and dehydrated the device
260260 device_id, access_token = self.get_success(
261261 self.registration.register_device(
262 user_id=user_id, device_id=None, initial_display_name="new device",
262 user_id=user_id,
263 device_id=None,
264 initial_display_name="new device",
263265 )
264266 )
265267
130130 """A user can create an alias for a room they're in."""
131131 self.get_success(
132132 self.handler.create_association(
133 create_requester(self.test_user), self.room_alias, self.room_id,
133 create_requester(self.test_user),
134 self.room_alias,
135 self.room_id,
134136 )
135137 )
136138
142144
143145 self.get_failure(
144146 self.handler.create_association(
145 create_requester(self.test_user), self.room_alias, other_room_id,
147 create_requester(self.test_user),
148 self.room_alias,
149 other_room_id,
146150 ),
147151 synapse.api.errors.SynapseError,
148152 )
155159
156160 self.get_success(
157161 self.handler.create_association(
158 create_requester(self.admin_user), self.room_alias, other_room_id,
162 create_requester(self.admin_user),
163 self.room_alias,
164 other_room_id,
159165 )
160166 )
161167
274280
275281
276282 class CanonicalAliasTestCase(unittest.HomeserverTestCase):
277 """Test modifications of the canonical alias when delete aliases.
278 """
283 """Test modifications of the canonical alias when delete aliases."""
279284
280285 servlets = [
281286 synapse.rest.admin.register_servlets,
316321 def _set_canonical_alias(self, content):
317322 """Configure the canonical alias state on the room."""
318323 self.helper.send_state(
319 self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
324 self.room_id,
325 "m.room.canonical_alias",
326 content,
327 tok=self.admin_user_tok,
320328 )
321329
322330 def _get_canonical_alias(self):
1717
1818 from signedjson import key as key, sign as sign
1919
20 from twisted.internet import defer
21
22 import synapse.handlers.e2e_keys
23 import synapse.storage
24 from synapse.api import errors
2520 from synapse.api.constants import RoomEncryptionAlgorithms
26
27 from tests import unittest, utils
28
29
30 class E2eKeysHandlerTestCase(unittest.TestCase):
31 def __init__(self, *args, **kwargs):
32 super().__init__(*args, **kwargs)
33 self.hs = None # type: synapse.server.HomeServer
34 self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
35 self.store = None # type: synapse.storage.Storage
36
37 @defer.inlineCallbacks
38 def setUp(self):
39 self.hs = yield utils.setup_test_homeserver(
40 self.addCleanup, federation_client=mock.Mock()
41 )
42 self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
21 from synapse.api.errors import Codes, SynapseError
22
23 from tests import unittest
24
25
26 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
27 def make_homeserver(self, reactor, clock):
28 return self.setup_test_homeserver(federation_client=mock.Mock())
29
30 def prepare(self, reactor, clock, hs):
31 self.handler = hs.get_e2e_keys_handler()
4332 self.store = self.hs.get_datastore()
4433
45 @defer.inlineCallbacks
4634 def test_query_local_devices_no_devices(self):
47 """If the user has no devices, we expect an empty list.
48 """
49 local_user = "@boris:" + self.hs.hostname
50 res = yield defer.ensureDeferred(
51 self.handler.query_local_devices({local_user: None})
52 )
35 """If the user has no devices, we expect an empty list."""
36 local_user = "@boris:" + self.hs.hostname
37 res = self.get_success(self.handler.query_local_devices({local_user: None}))
5338 self.assertDictEqual(res, {local_user: {}})
5439
55 @defer.inlineCallbacks
5640 def test_reupload_one_time_keys(self):
5741 """we should be able to re-upload the same keys"""
5842 local_user = "@boris:" + self.hs.hostname
6347 "alg2:k3": {"key": "key3"},
6448 }
6549
66 res = yield defer.ensureDeferred(
50 res = self.get_success(
6751 self.handler.upload_keys_for_user(
6852 local_user, device_id, {"one_time_keys": keys}
6953 )
7256
7357 # we should be able to change the signature without a problem
7458 keys["alg2:k2"]["signatures"]["k1"] = "sig2"
75 res = yield defer.ensureDeferred(
59 res = self.get_success(
7660 self.handler.upload_keys_for_user(
7761 local_user, device_id, {"one_time_keys": keys}
7862 )
7963 )
8064 self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
8165
82 @defer.inlineCallbacks
8366 def test_change_one_time_keys(self):
8467 """attempts to change one-time-keys should be rejected"""
8568
9174 "alg2:k3": {"key": "key3"},
9275 }
9376
94 res = yield defer.ensureDeferred(
77 res = self.get_success(
9578 self.handler.upload_keys_for_user(
9679 local_user, device_id, {"one_time_keys": keys}
9780 )
9881 )
9982 self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
10083
101 try:
102 yield defer.ensureDeferred(
103 self.handler.upload_keys_for_user(
104 local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
105 )
106 )
107 self.fail("No error when changing string key")
108 except errors.SynapseError:
109 pass
110
111 try:
112 yield defer.ensureDeferred(
113 self.handler.upload_keys_for_user(
114 local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
115 )
116 )
117 self.fail("No error when replacing dict key with string")
118 except errors.SynapseError:
119 pass
120
121 try:
122 yield defer.ensureDeferred(
123 self.handler.upload_keys_for_user(
124 local_user,
125 device_id,
126 {"one_time_keys": {"alg1:k1": {"key": "key"}}},
127 )
128 )
129 self.fail("No error when replacing string key with dict")
130 except errors.SynapseError:
131 pass
132
133 try:
134 yield defer.ensureDeferred(
135 self.handler.upload_keys_for_user(
136 local_user,
137 device_id,
138 {
139 "one_time_keys": {
140 "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
141 }
142 },
143 )
144 )
145 self.fail("No error when replacing dict key")
146 except errors.SynapseError:
147 pass
148
149 @defer.inlineCallbacks
84 # Error when changing string key
85 self.get_failure(
86 self.handler.upload_keys_for_user(
87 local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
88 ),
89 SynapseError,
90 )
91
92 # Error when replacing dict key with string
93 self.get_failure(
94 self.handler.upload_keys_for_user(
95 local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
96 ),
97 SynapseError,
98 )
99
100 # Error when replacing string key with dict
101 self.get_failure(
102 self.handler.upload_keys_for_user(
103 local_user,
104 device_id,
105 {"one_time_keys": {"alg1:k1": {"key": "key"}}},
106 ),
107 SynapseError,
108 )
109
110 # Error when replacing dict key
111 self.get_failure(
112 self.handler.upload_keys_for_user(
113 local_user,
114 device_id,
115 {
116 "one_time_keys": {
117 "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
118 }
119 },
120 ),
121 SynapseError,
122 )
123
150124 def test_claim_one_time_key(self):
151125 local_user = "@boris:" + self.hs.hostname
152126 device_id = "xyz"
153127 keys = {"alg1:k1": "key1"}
154128
155 res = yield defer.ensureDeferred(
129 res = self.get_success(
156130 self.handler.upload_keys_for_user(
157131 local_user, device_id, {"one_time_keys": keys}
158132 )
159133 )
160134 self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
161135
162 res2 = yield defer.ensureDeferred(
136 res2 = self.get_success(
163137 self.handler.claim_one_time_keys(
164138 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
165139 )
172146 },
173147 )
174148
175 @defer.inlineCallbacks
176149 def test_fallback_key(self):
177150 local_user = "@boris:" + self.hs.hostname
178151 device_id = "xyz"
180153 otk = {"alg1:k2": "key2"}
181154
182155 # we shouldn't have any unused fallback keys yet
183 res = yield defer.ensureDeferred(
156 res = self.get_success(
184157 self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
185158 )
186159 self.assertEqual(res, [])
187160
188 yield defer.ensureDeferred(
161 self.get_success(
189162 self.handler.upload_keys_for_user(
190163 local_user,
191164 device_id,
194167 )
195168
196169 # we should now have an unused alg1 key
197 res = yield defer.ensureDeferred(
170 res = self.get_success(
198171 self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
199172 )
200173 self.assertEqual(res, ["alg1"])
201174
202175 # claiming an OTK when no OTKs are available should return the fallback
203176 # key
204 res = yield defer.ensureDeferred(
177 res = self.get_success(
205178 self.handler.claim_one_time_keys(
206179 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
207180 )
212185 )
213186
214187 # we shouldn't have any unused fallback keys again
215 res = yield defer.ensureDeferred(
188 res = self.get_success(
216189 self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
217190 )
218191 self.assertEqual(res, [])
219192
220193 # claiming an OTK again should return the same fallback key
221 res = yield defer.ensureDeferred(
194 res = self.get_success(
222195 self.handler.claim_one_time_keys(
223196 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
224197 )
230203
231204 # if the user uploads a one-time key, the next claim should fetch the
232205 # one-time key, and then go back to the fallback
233 yield defer.ensureDeferred(
206 self.get_success(
234207 self.handler.upload_keys_for_user(
235208 local_user, device_id, {"one_time_keys": otk}
236209 )
237210 )
238211
239 res = yield defer.ensureDeferred(
212 res = self.get_success(
240213 self.handler.claim_one_time_keys(
241214 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
242215 )
243216 )
244217 self.assertEqual(
245 res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
246 )
247
248 res = yield defer.ensureDeferred(
218 res,
219 {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
220 )
221
222 res = self.get_success(
249223 self.handler.claim_one_time_keys(
250224 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
251225 )
255229 {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
256230 )
257231
258 @defer.inlineCallbacks
259232 def test_replace_master_key(self):
260233 """uploading a new signing key should make the old signing key unavailable"""
261234 local_user = "@boris:" + self.hs.hostname
269242 },
270243 }
271244 }
272 yield defer.ensureDeferred(
273 self.handler.upload_signing_keys_for_user(local_user, keys1)
274 )
245 self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
275246
276247 keys2 = {
277248 "master_key": {
283254 },
284255 }
285256 }
286 yield defer.ensureDeferred(
287 self.handler.upload_signing_keys_for_user(local_user, keys2)
288 )
289
290 devices = yield defer.ensureDeferred(
257 self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
258
259 devices = self.get_success(
291260 self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
292261 )
293262 self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
294263
295 @defer.inlineCallbacks
296264 def test_reupload_signatures(self):
297265 """re-uploading a signature should not fail"""
298266 local_user = "@boris:" + self.hs.hostname
325293 "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
326294 "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
327295 )
328 yield defer.ensureDeferred(
329 self.handler.upload_signing_keys_for_user(local_user, keys1)
330 )
296 self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
331297
332298 # upload two device keys, which will be signed later by the self-signing key
333299 device_key_1 = {
357323 "signatures": {local_user: {"ed25519:def": "base64+signature"}},
358324 }
359325
360 yield defer.ensureDeferred(
326 self.get_success(
361327 self.handler.upload_keys_for_user(
362328 local_user, "abc", {"device_keys": device_key_1}
363329 )
364330 )
365 yield defer.ensureDeferred(
331 self.get_success(
366332 self.handler.upload_keys_for_user(
367333 local_user, "def", {"device_keys": device_key_2}
368334 )
371337 # sign the first device key and upload it
372338 del device_key_1["signatures"]
373339 sign.sign_json(device_key_1, local_user, signing_key)
374 yield defer.ensureDeferred(
340 self.get_success(
375341 self.handler.upload_signatures_for_device_keys(
376342 local_user, {local_user: {"abc": device_key_1}}
377343 )
382348 # signature for it
383349 del device_key_2["signatures"]
384350 sign.sign_json(device_key_2, local_user, signing_key)
385 yield defer.ensureDeferred(
351 self.get_success(
386352 self.handler.upload_signatures_for_device_keys(
387353 local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
388354 )
390356
391357 device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
392358 device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
393 devices = yield defer.ensureDeferred(
359 devices = self.get_success(
394360 self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
395361 )
396362 del devices["device_keys"][local_user]["abc"]["unsigned"]
398364 self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
399365 self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
400366
401 @defer.inlineCallbacks
402367 def test_self_signing_key_doesnt_show_up_as_device(self):
403368 """signing keys should be hidden when fetching a user's devices"""
404369 local_user = "@boris:" + self.hs.hostname
412377 },
413378 }
414379 }
415 yield defer.ensureDeferred(
416 self.handler.upload_signing_keys_for_user(local_user, keys1)
417 )
418
419 res = None
420 try:
421 yield defer.ensureDeferred(
422 self.hs.get_device_handler().check_device_registered(
423 user_id=local_user,
424 device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
425 initial_device_display_name="new display name",
426 )
427 )
428 except errors.SynapseError as e:
429 res = e.code
380 self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
381
382 e = self.get_failure(
383 self.hs.get_device_handler().check_device_registered(
384 user_id=local_user,
385 device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
386 initial_device_display_name="new display name",
387 ),
388 SynapseError,
389 )
390 res = e.value.code
430391 self.assertEqual(res, 400)
431392
432 res = yield defer.ensureDeferred(
433 self.handler.query_local_devices({local_user: None})
434 )
393 res = self.get_success(self.handler.query_local_devices({local_user: None}))
435394 self.assertDictEqual(res, {local_user: {}})
436395
437 @defer.inlineCallbacks
438396 def test_upload_signatures(self):
439397 """should check signatures that are uploaded"""
440398 # set up a user with cross-signing keys and a device. This user will
457415 "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
458416 )
459417
460 yield defer.ensureDeferred(
418 self.get_success(
461419 self.handler.upload_keys_for_user(
462420 local_user, device_id, {"device_keys": device_key}
463421 )
500458 "user_signing_key": usersigning_key,
501459 "self_signing_key": selfsigning_key,
502460 }
503 yield defer.ensureDeferred(
461 self.get_success(
504462 self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
505463 )
506464
514472 "usage": ["master"],
515473 "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
516474 }
517 yield defer.ensureDeferred(
475 self.get_success(
518476 self.handler.upload_signing_keys_for_user(
519477 other_user, {"master_key": other_master_key}
520478 )
521479 )
522480
523481 # test various signature failures (see below)
524 ret = yield defer.ensureDeferred(
482 ret = self.get_success(
525483 self.handler.upload_signatures_for_device_keys(
526484 local_user,
527485 {
601559 )
602560
603561 user_failures = ret["failures"][local_user]
604 self.assertEqual(
605 user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
606 )
607 self.assertEqual(
608 user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
609 )
610 self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
562 self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
563 self.assertEqual(
564 user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
565 )
566 self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
611567
612568 other_user_failures = ret["failures"][other_user]
613 self.assertEqual(
614 other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
615 )
616 self.assertEqual(
617 other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
569 self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
570 self.assertEqual(
571 other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
618572 )
619573
620574 # test successful signatures
622576 sign.sign_json(device_key, local_user, selfsigning_signing_key)
623577 sign.sign_json(master_key, local_user, device_signing_key)
624578 sign.sign_json(other_master_key, local_user, usersigning_signing_key)
625 ret = yield defer.ensureDeferred(
579 ret = self.get_success(
626580 self.handler.upload_signatures_for_device_keys(
627581 local_user,
628582 {
635589 self.assertEqual(ret["failures"], {})
636590
637591 # fetch the signed keys/devices and make sure that the signatures are there
638 ret = yield defer.ensureDeferred(
592 ret = self.get_success(
639593 self.handler.query_devices(
640594 {"device_keys": {local_user: [], other_user: []}}, 0, local_user
641595 )
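The one-time-key tests converted above pin down an upload invariant: re-uploading a key is allowed only if it is unchanged apart from its signatures; replacing a string key, swapping a string for a dict (or vice versa), or altering the key material itself is rejected with a 400. A hypothetical in-memory sketch of that invariant (not Synapse's actual handler code, which also persists the keys):

```python
def _strip_signatures(key):
    # Signatures may legitimately change between uploads; compare
    # everything else.
    if isinstance(key, dict):
        return {k: v for k, v in key.items() if k != "signatures"}
    return key


def upload_one_time_keys(existing, new_keys):
    """Merge new_keys into existing, enforcing the re-upload invariant,
    and return per-algorithm counts as the real handler does."""
    for key_id, key in new_keys.items():
        if key_id in existing:
            old = existing[key_id]
            if type(old) is not type(key) or _strip_signatures(
                old
            ) != _strip_signatures(key):
                raise ValueError("400: one-time key %s already exists" % key_id)
        existing[key_id] = key

    counts = {}
    for key_id in existing:
        alg = key_id.split(":", 1)[0]
        counts[alg] = counts.get(alg, 0) + 1
    return {"one_time_key_counts": counts}
```

Under this rule, changing only `keys["alg2:k2"]["signatures"]` succeeds (as `test_reupload_one_time_keys` expects), while each of the four replacements in `test_change_one_time_keys` raises.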
1818
1919 import mock
2020
21 from twisted.internet import defer
22
23 import synapse.api.errors
24 import synapse.handlers.e2e_room_keys
25 import synapse.storage
26 from synapse.api import errors
27
28 from tests import unittest, utils
21 from synapse.api.errors import SynapseError
22
23 from tests import unittest
2924
3025 # sample room_key data for use in the tests
3126 room_keys = {
4439 }
4540
4641
47 class E2eRoomKeysHandlerTestCase(unittest.TestCase):
48 def __init__(self, *args, **kwargs):
49 super().__init__(*args, **kwargs)
50 self.hs = None # type: synapse.server.HomeServer
51 self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
52
53 @defer.inlineCallbacks
54 def setUp(self):
55 self.hs = yield utils.setup_test_homeserver(
56 self.addCleanup, replication_layer=mock.Mock()
57 )
58 self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
59 self.local_user = "@boris:" + self.hs.hostname
60
61 @defer.inlineCallbacks
42 class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
43 def make_homeserver(self, reactor, clock):
44 return self.setup_test_homeserver(replication_layer=mock.Mock())
45
46 def prepare(self, reactor, clock, hs):
47 self.handler = hs.get_e2e_room_keys_handler()
48 self.local_user = "@boris:" + hs.hostname
49
6250 def test_get_missing_current_version_info(self):
6351 """Check that we get a 404 if we ask for info about the current version
6452 if there is no version.
6553 """
66 res = None
67 try:
68 yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
69 except errors.SynapseError as e:
70 res = e.code
71 self.assertEqual(res, 404)
72
73 @defer.inlineCallbacks
54 e = self.get_failure(
55 self.handler.get_version_info(self.local_user), SynapseError
56 )
57 res = e.value.code
58 self.assertEqual(res, 404)
59
7460 def test_get_missing_version_info(self):
7561 """Check that we get a 404 if we ask for info about a specific version
7662 if it doesn't exist.
7763 """
78 res = None
79 try:
80 yield defer.ensureDeferred(
81 self.handler.get_version_info(self.local_user, "bogus_version")
82 )
83 except errors.SynapseError as e:
84 res = e.code
85 self.assertEqual(res, 404)
86
87 @defer.inlineCallbacks
64 e = self.get_failure(
65 self.handler.get_version_info(self.local_user, "bogus_version"),
66 SynapseError,
67 )
68 res = e.value.code
69 self.assertEqual(res, 404)
70
8871 def test_create_version(self):
89 """Check that we can create and then retrieve versions.
90 """
91 res = yield defer.ensureDeferred(
72 """Check that we can create and then retrieve versions."""
73 res = self.get_success(
9274 self.handler.create_version(
9375 self.local_user,
9476 {
10082 self.assertEqual(res, "1")
10183
10284 # check we can retrieve it as the current version
103 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
85 res = self.get_success(self.handler.get_version_info(self.local_user))
10486 version_etag = res["etag"]
10587 self.assertIsInstance(version_etag, str)
10688 del res["etag"]
11597 )
11698
11799 # check we can retrieve it as a specific version
118 res = yield defer.ensureDeferred(
119 self.handler.get_version_info(self.local_user, "1")
120 )
100 res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
121101 self.assertEqual(res["etag"], version_etag)
122102 del res["etag"]
123103 self.assertDictEqual(
131111 )
132112
133113 # upload a new one...
134 res = yield defer.ensureDeferred(
114 res = self.get_success(
135115 self.handler.create_version(
136116 self.local_user,
137117 {
143123 self.assertEqual(res, "2")
144124
145125 # check we can retrieve it as the current version
146 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
126 res = self.get_success(self.handler.get_version_info(self.local_user))
147127 del res["etag"]
148128 self.assertDictEqual(
149129 res,
155135 },
156136 )
157137
158 @defer.inlineCallbacks
159138 def test_update_version(self):
160 """Check that we can update versions.
161 """
162 version = yield defer.ensureDeferred(
163 self.handler.create_version(
164 self.local_user,
165 {
166 "algorithm": "m.megolm_backup.v1",
167 "auth_data": "first_version_auth_data",
168 },
169 )
170 )
171 self.assertEqual(version, "1")
172
173 res = yield defer.ensureDeferred(
139 """Check that we can update versions."""
140 version = self.get_success(
141 self.handler.create_version(
142 self.local_user,
143 {
144 "algorithm": "m.megolm_backup.v1",
145 "auth_data": "first_version_auth_data",
146 },
147 )
148 )
149 self.assertEqual(version, "1")
150
151 res = self.get_success(
174152 self.handler.update_version(
175153 self.local_user,
176154 version,
184162 self.assertDictEqual(res, {})
185163
186164 # check we can retrieve it as the current version
187 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
165 res = self.get_success(self.handler.get_version_info(self.local_user))
188166 del res["etag"]
189167 self.assertDictEqual(
190168 res,
196174 },
197175 )
198176
199 @defer.inlineCallbacks
200177 def test_update_missing_version(self):
201 """Check that we get a 404 on updating nonexistent versions
202 """
203 res = None
204 try:
205 yield defer.ensureDeferred(
206 self.handler.update_version(
207 self.local_user,
208 "1",
209 {
210 "algorithm": "m.megolm_backup.v1",
211 "auth_data": "revised_first_version_auth_data",
212 "version": "1",
213 },
214 )
215 )
216 except errors.SynapseError as e:
217 res = e.code
218 self.assertEqual(res, 404)
219
220 @defer.inlineCallbacks
178 """Check that we get a 404 on updating nonexistent versions"""
179 e = self.get_failure(
180 self.handler.update_version(
181 self.local_user,
182 "1",
183 {
184 "algorithm": "m.megolm_backup.v1",
185 "auth_data": "revised_first_version_auth_data",
186 "version": "1",
187 },
188 ),
189 SynapseError,
190 )
191 res = e.value.code
192 self.assertEqual(res, 404)
193
221194 def test_update_omitted_version(self):
222 """Check that the update succeeds if the version is missing from the body
223 """
224 version = yield defer.ensureDeferred(
225 self.handler.create_version(
226 self.local_user,
227 {
228 "algorithm": "m.megolm_backup.v1",
229 "auth_data": "first_version_auth_data",
230 },
231 )
232 )
233 self.assertEqual(version, "1")
234
235 yield defer.ensureDeferred(
195 """Check that the update succeeds if the version is missing from the body"""
196 version = self.get_success(
197 self.handler.create_version(
198 self.local_user,
199 {
200 "algorithm": "m.megolm_backup.v1",
201 "auth_data": "first_version_auth_data",
202 },
203 )
204 )
205 self.assertEqual(version, "1")
206
207 self.get_success(
236208 self.handler.update_version(
237209 self.local_user,
238210 version,
244216 )
245217
246218 # check we can retrieve it as the current version
247 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
219 res = self.get_success(self.handler.get_version_info(self.local_user))
248220 del res["etag"] # etag is opaque, so don't test its contents
249221 self.assertDictEqual(
250222 res,
256228 },
257229 )
258230
259 @defer.inlineCallbacks
260231 def test_update_bad_version(self):
261 """Check that we get a 400 if the version in the body doesn't match
262 """
263 version = yield defer.ensureDeferred(
264 self.handler.create_version(
265 self.local_user,
266 {
267 "algorithm": "m.megolm_backup.v1",
268 "auth_data": "first_version_auth_data",
269 },
270 )
271 )
272 self.assertEqual(version, "1")
273
274 res = None
275 try:
276 yield defer.ensureDeferred(
277 self.handler.update_version(
278 self.local_user,
279 version,
280 {
281 "algorithm": "m.megolm_backup.v1",
282 "auth_data": "revised_first_version_auth_data",
283 "version": "incorrect",
284 },
285 )
286 )
287 except errors.SynapseError as e:
288 res = e.code
232 """Check that we get a 400 if the version in the body doesn't match"""
233 version = self.get_success(
234 self.handler.create_version(
235 self.local_user,
236 {
237 "algorithm": "m.megolm_backup.v1",
238 "auth_data": "first_version_auth_data",
239 },
240 )
241 )
242 self.assertEqual(version, "1")
243
244 e = self.get_failure(
245 self.handler.update_version(
246 self.local_user,
247 version,
248 {
249 "algorithm": "m.megolm_backup.v1",
250 "auth_data": "revised_first_version_auth_data",
251 "version": "incorrect",
252 },
253 ),
254 SynapseError,
255 )
256 res = e.value.code
289257 self.assertEqual(res, 400)
290258
291 @defer.inlineCallbacks
292259 def test_delete_missing_version(self):
293 """Check that we get a 404 on deleting nonexistent versions
294 """
295 res = None
296 try:
297 yield defer.ensureDeferred(
298 self.handler.delete_version(self.local_user, "1")
299 )
300 except errors.SynapseError as e:
301 res = e.code
302 self.assertEqual(res, 404)
303
304 @defer.inlineCallbacks
260 """Check that we get a 404 on deleting nonexistent versions"""
261 e = self.get_failure(
262 self.handler.delete_version(self.local_user, "1"), SynapseError
263 )
264 res = e.value.code
265 self.assertEqual(res, 404)
266
305267 def test_delete_missing_current_version(self):
306 """Check that we get a 404 on deleting nonexistent current version
307 """
308 res = None
309 try:
310 yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
311 except errors.SynapseError as e:
312 res = e.code
313 self.assertEqual(res, 404)
314
315 @defer.inlineCallbacks
268 """Check that we get a 404 on deleting nonexistent current version"""
269 e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
270 res = e.value.code
271 self.assertEqual(res, 404)
272
316273 def test_delete_version(self):
317 """Check that we can create and then delete versions.
318 """
319 res = yield defer.ensureDeferred(
274 """Check that we can create and then delete versions."""
275 res = self.get_success(
320276 self.handler.create_version(
321277 self.local_user,
322278 {
328284 self.assertEqual(res, "1")
329285
330286 # check we can delete it
331 yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
287 self.get_success(self.handler.delete_version(self.local_user, "1"))
332288
333289 # check that it's gone
334 res = None
335 try:
336 yield defer.ensureDeferred(
337 self.handler.get_version_info(self.local_user, "1")
338 )
339 except errors.SynapseError as e:
340 res = e.code
341 self.assertEqual(res, 404)
342
343 @defer.inlineCallbacks
290 e = self.get_failure(
291 self.handler.get_version_info(self.local_user, "1"), SynapseError
292 )
293 res = e.value.code
294 self.assertEqual(res, 404)
295
344296 def test_get_missing_backup(self):
345 """Check that we get a 404 on querying missing backup
346 """
347 res = None
348 try:
349 yield defer.ensureDeferred(
350 self.handler.get_room_keys(self.local_user, "bogus_version")
351 )
352 except errors.SynapseError as e:
353 res = e.code
354 self.assertEqual(res, 404)
355
356 @defer.inlineCallbacks
297 """Check that we get a 404 on querying missing backup"""
298 e = self.get_failure(
299 self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
300 )
301 res = e.value.code
302 self.assertEqual(res, 404)
303
357304 def test_get_missing_room_keys(self):
358 """Check we get an empty response from an empty backup
359 """
360 version = yield defer.ensureDeferred(
361 self.handler.create_version(
362 self.local_user,
363 {
364 "algorithm": "m.megolm_backup.v1",
365 "auth_data": "first_version_auth_data",
366 },
367 )
368 )
369 self.assertEqual(version, "1")
370
371 res = yield defer.ensureDeferred(
372 self.handler.get_room_keys(self.local_user, version)
373 )
305 """Check we get an empty response from an empty backup"""
306 version = self.get_success(
307 self.handler.create_version(
308 self.local_user,
309 {
310 "algorithm": "m.megolm_backup.v1",
311 "auth_data": "first_version_auth_data",
312 },
313 )
314 )
315 self.assertEqual(version, "1")
316
317 res = self.get_success(self.handler.get_room_keys(self.local_user, version))
374318 self.assertDictEqual(res, {"rooms": {}})
375319
376320 # TODO: test the locking semantics when uploading room_keys,
377321 # although this is probably best done in sytest
378322
379 @defer.inlineCallbacks
380323 def test_upload_room_keys_no_versions(self):
381 """Check that we get a 404 on uploading keys when no versions are defined
382 """
383 res = None
384 try:
385 yield defer.ensureDeferred(
386 self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
387 )
388 except errors.SynapseError as e:
389 res = e.code
390 self.assertEqual(res, 404)
391
392 @defer.inlineCallbacks
324 """Check that we get a 404 on uploading keys when no versions are defined"""
325 e = self.get_failure(
326 self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
327 SynapseError,
328 )
329 res = e.value.code
330 self.assertEqual(res, 404)
331
393332 def test_upload_room_keys_bogus_version(self):
394333	 """Check that we get a 404 on uploading keys when a nonexistent version
395334 is specified
396335 """
397 version = yield defer.ensureDeferred(
398 self.handler.create_version(
399 self.local_user,
400 {
401 "algorithm": "m.megolm_backup.v1",
402 "auth_data": "first_version_auth_data",
403 },
404 )
405 )
406 self.assertEqual(version, "1")
407
408 res = None
409 try:
410 yield defer.ensureDeferred(
411 self.handler.upload_room_keys(
412 self.local_user, "bogus_version", room_keys
413 )
414 )
415 except errors.SynapseError as e:
416 res = e.code
417 self.assertEqual(res, 404)
418
419 @defer.inlineCallbacks
336 version = self.get_success(
337 self.handler.create_version(
338 self.local_user,
339 {
340 "algorithm": "m.megolm_backup.v1",
341 "auth_data": "first_version_auth_data",
342 },
343 )
344 )
345 self.assertEqual(version, "1")
346
347 e = self.get_failure(
348 self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
349 SynapseError,
350 )
351 res = e.value.code
352 self.assertEqual(res, 404)
353
420354 def test_upload_room_keys_wrong_version(self):
421 """Check that we get a 403 on uploading keys for an old version
422 """
423 version = yield defer.ensureDeferred(
424 self.handler.create_version(
425 self.local_user,
426 {
427 "algorithm": "m.megolm_backup.v1",
428 "auth_data": "first_version_auth_data",
429 },
430 )
431 )
432 self.assertEqual(version, "1")
433
434 version = yield defer.ensureDeferred(
355 """Check that we get a 403 on uploading keys for an old version"""
356 version = self.get_success(
357 self.handler.create_version(
358 self.local_user,
359 {
360 "algorithm": "m.megolm_backup.v1",
361 "auth_data": "first_version_auth_data",
362 },
363 )
364 )
365 self.assertEqual(version, "1")
366
367 version = self.get_success(
435368 self.handler.create_version(
436369 self.local_user,
437370 {
442375 )
443376 self.assertEqual(version, "2")
444377
445 res = None
446 try:
447 yield defer.ensureDeferred(
448 self.handler.upload_room_keys(self.local_user, "1", room_keys)
449 )
450 except errors.SynapseError as e:
451 res = e.code
378 e = self.get_failure(
379 self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
380 )
381 res = e.value.code
452382 self.assertEqual(res, 403)
453383
454 @defer.inlineCallbacks
455384 def test_upload_room_keys_insert(self):
456 """Check that we can insert and retrieve keys for a session
457 """
458 version = yield defer.ensureDeferred(
459 self.handler.create_version(
460 self.local_user,
461 {
462 "algorithm": "m.megolm_backup.v1",
463 "auth_data": "first_version_auth_data",
464 },
465 )
466 )
467 self.assertEqual(version, "1")
468
469 yield defer.ensureDeferred(
385 """Check that we can insert and retrieve keys for a session"""
386 version = self.get_success(
387 self.handler.create_version(
388 self.local_user,
389 {
390 "algorithm": "m.megolm_backup.v1",
391 "auth_data": "first_version_auth_data",
392 },
393 )
394 )
395 self.assertEqual(version, "1")
396
397 self.get_success(
470398 self.handler.upload_room_keys(self.local_user, version, room_keys)
471399 )
472400
473 res = yield defer.ensureDeferred(
474 self.handler.get_room_keys(self.local_user, version)
475 )
401 res = self.get_success(self.handler.get_room_keys(self.local_user, version))
476402 self.assertDictEqual(res, room_keys)
477403
478404 # check getting room_keys for a given room
479 res = yield defer.ensureDeferred(
405 res = self.get_success(
480406 self.handler.get_room_keys(
481407 self.local_user, version, room_id="!abc:matrix.org"
482408 )
484410 self.assertDictEqual(res, room_keys)
485411
486412 # check getting room_keys for a given session_id
487 res = yield defer.ensureDeferred(
413 res = self.get_success(
488414 self.handler.get_room_keys(
489415 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
490416 )
491417 )
492418 self.assertDictEqual(res, room_keys)
493419
494 @defer.inlineCallbacks
495420 def test_upload_room_keys_merge(self):
496421 """Check that we can upload a new room_key for an existing session and
497422 have it correctly merged"""
498 version = yield defer.ensureDeferred(
499 self.handler.create_version(
500 self.local_user,
501 {
502 "algorithm": "m.megolm_backup.v1",
503 "auth_data": "first_version_auth_data",
504 },
505 )
506 )
507 self.assertEqual(version, "1")
508
509 yield defer.ensureDeferred(
423 version = self.get_success(
424 self.handler.create_version(
425 self.local_user,
426 {
427 "algorithm": "m.megolm_backup.v1",
428 "auth_data": "first_version_auth_data",
429 },
430 )
431 )
432 self.assertEqual(version, "1")
433
434 self.get_success(
510435 self.handler.upload_room_keys(self.local_user, version, room_keys)
511436 )
512437
513438 # get the etag to compare to future versions
514 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
439 res = self.get_success(self.handler.get_version_info(self.local_user))
515440 backup_etag = res["etag"]
516441 self.assertEqual(res["count"], 1)
517442
521446 # test that increasing the message_index doesn't replace the existing session
522447 new_room_key["first_message_index"] = 2
523448 new_room_key["session_data"] = "new"
524 yield defer.ensureDeferred(
449 self.get_success(
525450 self.handler.upload_room_keys(self.local_user, version, new_room_keys)
526451 )
527452
528 res = yield defer.ensureDeferred(
529 self.handler.get_room_keys(self.local_user, version)
530 )
453 res = self.get_success(self.handler.get_room_keys(self.local_user, version))
531454 self.assertEqual(
532455 res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
533456 "SSBBTSBBIEZJU0gK",
534457 )
535458
536459 # the etag should be the same since the session did not change
537 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
460 res = self.get_success(self.handler.get_version_info(self.local_user))
538461 self.assertEqual(res["etag"], backup_etag)
539462
540463 # test that marking the session as verified however /does/ replace it
541464 new_room_key["is_verified"] = True
542 yield defer.ensureDeferred(
465 self.get_success(
543466 self.handler.upload_room_keys(self.local_user, version, new_room_keys)
544467 )
545468
546 res = yield defer.ensureDeferred(
547 self.handler.get_room_keys(self.local_user, version)
548 )
469 res = self.get_success(self.handler.get_room_keys(self.local_user, version))
549470 self.assertEqual(
550471 res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
551472 )
552473
553474 # the etag should NOT be equal now, since the key changed
554 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
475 res = self.get_success(self.handler.get_version_info(self.local_user))
555476 self.assertNotEqual(res["etag"], backup_etag)
556477 backup_etag = res["etag"]
557478
559480 # with a lower forwarding count
560481 new_room_key["forwarded_count"] = 2
561482 new_room_key["session_data"] = "other"
562 yield defer.ensureDeferred(
483 self.get_success(
563484 self.handler.upload_room_keys(self.local_user, version, new_room_keys)
564485 )
565486
566 res = yield defer.ensureDeferred(
567 self.handler.get_room_keys(self.local_user, version)
568 )
487 res = self.get_success(self.handler.get_room_keys(self.local_user, version))
569488 self.assertEqual(
570489 res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
571490 )
572491
573492 # the etag should be the same since the session did not change
574 res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
493 res = self.get_success(self.handler.get_version_info(self.local_user))
575494 self.assertEqual(res["etag"], backup_etag)
576495
577496 # TODO: check edge cases as well as the common variations here
578497
579 @defer.inlineCallbacks
580498 def test_delete_room_keys(self):
581 """Check that we can insert and delete keys for a session
582 """
583 version = yield defer.ensureDeferred(
499 """Check that we can insert and delete keys for a session"""
500 version = self.get_success(
584501 self.handler.create_version(
585502 self.local_user,
586503 {
592509 self.assertEqual(version, "1")
593510
594511 # check for bulk-delete
595 yield defer.ensureDeferred(
512 self.get_success(
596513 self.handler.upload_room_keys(self.local_user, version, room_keys)
597514 )
598 yield defer.ensureDeferred(
599 self.handler.delete_room_keys(self.local_user, version)
600 )
601 res = yield defer.ensureDeferred(
515 self.get_success(self.handler.delete_room_keys(self.local_user, version))
516 res = self.get_success(
602517 self.handler.get_room_keys(
603518 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
604519 )
606521 self.assertDictEqual(res, {"rooms": {}})
607522
608523 # check for bulk-delete per room
609 yield defer.ensureDeferred(
524 self.get_success(
610525 self.handler.upload_room_keys(self.local_user, version, room_keys)
611526 )
612 yield defer.ensureDeferred(
527 self.get_success(
613528 self.handler.delete_room_keys(
614529 self.local_user, version, room_id="!abc:matrix.org"
615530 )
616531 )
617 res = yield defer.ensureDeferred(
532 res = self.get_success(
618533 self.handler.get_room_keys(
619534 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
620535 )
622537 self.assertDictEqual(res, {"rooms": {}})
623538
624539 # check for bulk-delete per session
625 yield defer.ensureDeferred(
540 self.get_success(
626541 self.handler.upload_room_keys(self.local_user, version, room_keys)
627542 )
628 yield defer.ensureDeferred(
543 self.get_success(
629544 self.handler.delete_room_keys(
630545 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
631546 )
632547 )
633 res = yield defer.ensureDeferred(
548 res = self.get_success(
634549 self.handler.get_room_keys(
635550 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
636551 )
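The dominant change across these hunks is mechanical: the old `@defer.inlineCallbacks` / `yield defer.ensureDeferred(...)` / `try ... except errors.SynapseError` pattern becomes `HomeserverTestCase`'s `get_success`/`get_failure` helpers, with the status code now read from `e.value.code` because `get_failure` returns a Twisted `Failure`, which carries the raised exception on `.value`. The helpers' contract can be sketched without Twisted — `SynapseError` and `_Failure` below are local stand-ins, and `asyncio` replaces Synapse's real reactor-driven implementation:

```python
import asyncio


class SynapseError(Exception):
    # local stand-in for synapse.api.errors.SynapseError; only .code matters here
    def __init__(self, code):
        super().__init__(code)
        self.code = code


class _Failure:
    # mimics Twisted's Failure wrapper: the raised exception sits on .value
    def __init__(self, value):
        self.value = value


def get_success(coro):
    # drive the coroutine to completion synchronously and return its result
    return asyncio.run(coro)


def get_failure(coro, exc_type):
    # expect the coroutine to raise exc_type; return a Failure-like wrapper
    try:
        asyncio.run(coro)
    except exc_type as e:
        return _Failure(e)
    raise AssertionError("expected %r to be raised" % exc_type)
```

This is why every `res = None` / `try` / `except` block in the old code collapses to a single `get_failure(...)` call followed by `e.value.code`.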
225225 for i in range(3):
226226 event = create_invite()
227227 self.get_success(
228 self.handler.on_invite_request(other_server, event, event.room_version,)
228 self.handler.on_invite_request(
229 other_server,
230 event,
231 event.room_version,
232 )
229233 )
230234
231235 event = create_invite()
232236 self.get_failure(
233 self.handler.on_invite_request(other_server, event, event.room_version,),
237 self.handler.on_invite_request(
238 other_server,
239 event,
240 event.room_version,
241 ),
234242 exc=LimitExceededError,
235243 )
236244
4343 self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
4444
4545 self.info = self.get_success(
46 self.hs.get_datastore().get_user_by_access_token(self.access_token,)
46 self.hs.get_datastore().get_user_by_access_token(
47 self.access_token,
48 )
4749 )
4850 self.token_id = self.info.token_id
4951
168170 self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
169171
170172 def test_allow_server_acl(self):
171 """Test that sending an ACL that blocks everyone but ourselves works.
172 """
173 """Test that sending an ACL that blocks everyone but ourselves works."""
173174
174175 self.helper.send_state(
175176 self.room_id,
180181 )
181182
182183 def test_deny_server_acl_block_outselves(self):
183 """Test that sending an ACL that blocks ourselves does not work.
184 """
184 """Test that sending an ACL that blocks ourselves does not work."""
185185 self.helper.send_state(
186186 self.room_id,
187187 EventTypes.ServerACL,
191191 )
192192
193193 def test_deny_redact_server_acl(self):
194 """Test that attempting to redact an ACL is blocked.
195 """
194 """Test that attempting to redact an ACL is blocked."""
196195
197196 body = self.helper.send_state(
198197 self.room_id,
2323 from synapse.server import HomeServer
2424 from synapse.types import UserID
2525
26 from tests.test_utils import FakeResponse, simple_async_mock
26 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
2727 from tests.unittest import HomeserverTestCase, override_config
2828
2929 try:
130130 return config
131131
132132 def make_homeserver(self, reactor, clock):
133
134133 self.http_client = Mock(spec=["get_json"])
135134 self.http_client.get_json.side_effect = get_json
136135 self.http_client.user_agent = "Synapse Test"
150149 return hs
151150
152151 def metadata_edit(self, values):
153 return patch.dict(self.provider._provider_metadata, values)
152 """Modify the result that will be returned by the well-known query"""
153
154 async def patched_get_json(uri):
155 res = await get_json(uri)
156 if uri == WELL_KNOWN:
157 res.update(values)
158 return res
159
160 return patch.object(self.http_client, "get_json", patched_get_json)
154161
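`metadata_edit` now patches the mocked `http_client.get_json` instead of reaching into `provider._provider_metadata`, overlaying the given values on the well-known response. The wrapping technique can be sketched in isolation — `FakeClient` and this `metadata_edit` are local stand-ins, not the test's actual fixtures:

```python
import asyncio
from unittest.mock import patch


class FakeClient:
    # hypothetical stand-in for the mocked http_client in the test
    async def get_json(self, uri):
        return {"issuer": "https://issuer/", "jwks_uri": "https://issuer/jwks"}


def metadata_edit(client, values):
    # capture the real coroutine, then overlay `values` on its result,
    # mirroring the patched_get_json approach in the diff
    original = client.get_json

    async def patched_get_json(uri):
        res = await original(uri)
        res.update(values)
        return res

    return patch.object(client, "get_json", patched_get_json)
```

Used as a context manager, the patch applies only inside the `with` block and `patch.object` restores the original attribute on exit, so later queries see the unmodified metadata again.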
155162 def assertRenderedError(self, error, error_description=None):
156163 self.render_error.assert_called_once()
211218 self.http_client.get_json.assert_called_once_with(JWKS_URI)
212219
213220 # Throw if the JWKS uri is missing
214 with self.metadata_edit({"jwks_uri": None}):
221 original = self.provider.load_metadata
222
223 async def patched_load_metadata():
224 m = (await original()).copy()
225 m.update({"jwks_uri": None})
226 return m
227
228 with patch.object(self.provider, "load_metadata", patched_load_metadata):
215229 self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
216230
217231 # Return empty key set if JWKS are not used
221235 self.http_client.get_json.assert_not_called()
222236 self.assertEqual(jwks, {"keys": []})
223237
224 @override_config({"oidc_config": COMMON_CONFIG})
225238 def test_validate_config(self):
226239	 """Provider metadata is extensively validated."""
227240 h = self.provider
228241
242 def force_load_metadata():
243 async def force_load():
244 return await h.load_metadata(force=True)
245
246 return get_awaitable_result(force_load())
247
229248 # Default test config does not throw
230 h._validate_metadata()
249 force_load_metadata()
231250
232251 with self.metadata_edit({"issuer": None}):
233 self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
252 self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
234253
235254 with self.metadata_edit({"issuer": "http://insecure/"}):
236 self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
255 self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
237256
238257 with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
239 self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
258 self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
240259
241260 with self.metadata_edit({"authorization_endpoint": None}):
242261 self.assertRaisesRegex(
243 ValueError, "authorization_endpoint", h._validate_metadata
262 ValueError, "authorization_endpoint", force_load_metadata
244263 )
245264
246265 with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
247266 self.assertRaisesRegex(
248 ValueError, "authorization_endpoint", h._validate_metadata
267 ValueError, "authorization_endpoint", force_load_metadata
249268 )
250269
251270 with self.metadata_edit({"token_endpoint": None}):
252 self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
271 self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
253272
254273 with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
255 self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
274 self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
256275
257276 with self.metadata_edit({"jwks_uri": None}):
258 self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
277 self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
259278
260279 with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
261 self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
280 self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
262281
263282 with self.metadata_edit({"response_types_supported": ["id_token"]}):
264283 self.assertRaisesRegex(
265 ValueError, "response_types_supported", h._validate_metadata
284 ValueError, "response_types_supported", force_load_metadata
266285 )
267286
268287 with self.metadata_edit(
269288 {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
270289 ):
271290 # should not throw, as client_secret_basic is the default auth method
272 h._validate_metadata()
291 force_load_metadata()
273292
274293 with self.metadata_edit(
275294 {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
277296 self.assertRaisesRegex(
278297 ValueError,
279298 "token_endpoint_auth_methods_supported",
280 h._validate_metadata,
299 force_load_metadata,
281300 )
282301
283302 # Tests for configs that require the userinfo endpoint
286305 h._user_profile_method = "userinfo_endpoint"
287306 self.assertTrue(h._uses_userinfo)
288307
289 # Revert the profile method and do not request the "openid" scope.
308 # Revert the profile method and do not request the "openid" scope: this should
309 # mean that we check for a userinfo endpoint
290310 h._user_profile_method = "auto"
291311 h._scopes = []
292312 self.assertTrue(h._uses_userinfo)
293 self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
294
295 with self.metadata_edit(
296 {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
297 ):
298 # Shouldn't raise with a valid userinfo, even without
299 h._validate_metadata()
313 with self.metadata_edit({"userinfo_endpoint": None}):
314 self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)
315
316 with self.metadata_edit({"jwks_uri": None}):
317 # Shouldn't raise with a valid userinfo, even without jwks
318 force_load_metadata()
300319
301320 @override_config({"oidc_config": {"skip_verification": True}})
302321 def test_skip_verification(self):
303322 """Provider metadata validation can be disabled by config."""
304323 with self.metadata_edit({"issuer": "http://insecure"}):
305324 # This should not throw
306 self.provider._validate_metadata()
325 get_awaitable_result(self.provider.load_metadata())
307326
308327 def test_redirect_request(self):
309328 """The redirect request has the right arguments & generates a valid session cookie."""
310 req = Mock(spec=["addCookie"])
329 req = Mock(spec=["cookies"])
330 req.cookies = []
331
311332 url = self.get_success(
312333 self.provider.handle_redirect_request(req, b"http://client/redirect")
313334 )
326347 self.assertEqual(len(params["state"]), 1)
327348 self.assertEqual(len(params["nonce"]), 1)
328349
329 # Check what is in the cookie
330 # note: python3.5 mock does not have the .called_once() method
331 calls = req.addCookie.call_args_list
332 self.assertEqual(len(calls), 1) # called once
333 # For some reason, call.args does not work with python3.5
334 args = calls[0][0]
335 kwargs = calls[0][1]
350 # Check what is in the cookies
351 self.assertEqual(len(req.cookies), 2) # two cookies
352 cookie_header = req.cookies[0]
336353
337354 # The cookie name and path don't really matter, just that it has to be coherent
338355 # between the callback & redirect handlers.
339 self.assertEqual(args[0], b"oidc_session")
340 self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
341 cookie = args[1]
356 parts = [p.strip() for p in cookie_header.split(b";")]
357 self.assertIn(b"Path=/_synapse/client/oidc", parts)
358 name, cookie = parts[0].split(b"=")
359 self.assertEqual(name, b"oidc_session")
342360
343361 macaroon = pymacaroons.Macaroon.deserialize(cookie)
344362 state = self.handler._token_generator._get_value_from_macaroon(
469487
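With `Mock(spec=["cookies"])`, the redirect test now inspects raw `Set-Cookie` header lines appended to `request.cookies` rather than `addCookie` call arguments, splitting each header on `;` to recover the cookie name, value, and attributes. A self-contained sketch of that parsing — `find_cookie` is a hypothetical helper, not part of the test:

```python
def find_cookie(cookie_headers, name):
    # scan raw Set-Cookie header lines (as stored on request.cookies)
    # for the named cookie; return its value and remaining attributes
    for header in cookie_headers:
        parts = [p.strip() for p in header.split(b";")]
        key, _, value = parts[0].partition(b"=")
        if key == name:
            return value, parts[1:]
    return None, []
```

The value recovered this way is the serialized macaroon, which the test then deserializes with `pymacaroons` exactly as before.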
470488 def test_callback_session(self):
471489 """The callback verifies the session presence and validity"""
472 request = Mock(spec=["args", "getCookie", "addCookie"])
490 request = Mock(spec=["args", "getCookie", "cookies"])
473491
474492 # Missing cookie
475493 request.args = {}
492510
493511 # Mismatching session
494512 session = self._generate_oidc_session_token(
495 state="state", nonce="nonce", client_redirect_url="http://client/redirect",
513 state="state",
514 nonce="nonce",
515 client_redirect_url="http://client/redirect",
496516 )
497517 request.args = {}
498518 request.args[b"state"] = [b"mismatching state"]
547567 # Internal server error with no JSON body
548568 self.http_client.request = simple_async_mock(
549569 return_value=FakeResponse(
550 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
570 code=500,
571 phrase=b"Internal Server Error",
572 body=b"Not JSON",
551573 )
552574 )
553575 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
567589
568590 # 4xx error without "error" field
569591 self.http_client.request = simple_async_mock(
570 return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
592 return_value=FakeResponse(
593 code=400,
594 phrase=b"Bad request",
595 body=b"{}",
596 )
571597 )
572598 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
573599 self.assertEqual(exc.value.error, "server_error")
575601 # 2xx error with "error" field
576602 self.http_client.request = simple_async_mock(
577603 return_value=FakeResponse(
578 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
604 code=200,
605 phrase=b"OK",
606 body=b'{"error": "some_error"}',
579607 )
580608 )
581609 exc = self.get_failure(self.provider._exchange_code(code), OidcError)
612640 state = "state"
613641 client_redirect_url = "http://client/redirect"
614642 session = self._generate_oidc_session_token(
615 state=state, nonce="nonce", client_redirect_url=client_redirect_url,
643 state=state,
644 nonce="nonce",
645 client_redirect_url=client_redirect_url,
616646 )
617647 request = _build_callback_request("code", state, session)
618648
875905 session = handler._token_generator.generate_oidc_session_token(
876906 state=state,
877907 session_data=OidcSessionData(
878 idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
908 idp_id="oidc",
909 nonce="nonce",
910 client_redirect_url=client_redirect_url,
879911 ),
880912 )
881913 request = _build_callback_request("code", state, session)
909941 spec=[
910942 "args",
911943 "getCookie",
912 "addCookie",
944 "cookies",
913945 "requestHeaders",
914946 "getClientIP",
915947 "getHeader",
916948 ]
917949 )
918950
951 request.cookies = []
919952 request.getCookie.return_value = session
920953 request.args = {}
921954 request.args[b"code"] = [code.encode("utf-8")]
230230 }
231231 )
232232 def test_no_local_user_fallback_login(self):
233 """localdb_enabled can block login with the local password
234 """
233 """localdb_enabled can block login with the local password"""
235234 self.register_user("localuser", "localpass")
236235
237236 # check_password must return an awaitable
250249 }
251250 )
252251 def test_no_local_user_fallback_ui_auth(self):
253 """localdb_enabled can block ui auth with the local password
254 """
252 """localdb_enabled can block ui auth with the local password"""
255253 self.register_user("localuser", "localpass")
256254
257255 # allow login via the auth provider
593591 )
594592
595593 def _delete_device(
596 self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
594 self,
595 access_token: str,
596 device: str,
597 body: Union[JsonDict, bytes] = b"",
597598 ) -> FakeChannel:
598599 """Delete an individual device."""
599600 channel = self.make_request(
588588 )
589589
590590 def _add_new_user(self, room_id, user_id):
591 """Add new user to the room by creating an event and poking the federation API.
592 """
591 """Add new user to the room by creating an event and poking the federation API."""
593592
594593 hostname = get_domain_from_id(user_id)
595594
1212 # See the License for the specific language governing permissions and
1313 # limitations under the License.
1414
15
1615 from mock import Mock
17
18 from twisted.internet import defer
1916
2017 import synapse.types
2118 from synapse.api.errors import AuthError, SynapseError
2320
2421 from tests import unittest
2522 from tests.test_utils import make_awaitable
26 from tests.utils import setup_test_homeserver
27
28
29 class ProfileTestCase(unittest.TestCase):
23
24
25 class ProfileTestCase(unittest.HomeserverTestCase):
3026 """ Tests profile management. """
3127
32 @defer.inlineCallbacks
33 def setUp(self):
28 def make_homeserver(self, reactor, clock):
3429 self.mock_federation = Mock()
3530 self.mock_registry = Mock()
3631
4136
4237 self.mock_registry.register_query_handler = register_query_handler
4338
44 hs = yield setup_test_homeserver(
45 self.addCleanup,
39 hs = self.setup_test_homeserver(
4640 federation_client=self.mock_federation,
4741 federation_server=Mock(),
4842 federation_registry=self.mock_registry,
4943 )
50
44 return hs
45
46 def prepare(self, reactor, clock, hs):
5147 self.store = hs.get_datastore()
5248
5349 self.frank = UserID.from_string("@1234ABCD:test")
5450 self.bob = UserID.from_string("@4567:test")
5551 self.alice = UserID.from_string("@alice:remote")
5652
57 yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
53 self.get_success(self.store.create_profile(self.frank.localpart))
5854
5955 self.handler = hs.get_profile_handler()
60 self.hs = hs
61
62 @defer.inlineCallbacks
56
6357 def test_get_my_name(self):
64 yield defer.ensureDeferred(
58 self.get_success(
6559 self.store.set_profile_displayname(self.frank.localpart, "Frank")
6660 )
6761
68 displayname = yield defer.ensureDeferred(
69 self.handler.get_displayname(self.frank)
70 )
62 displayname = self.get_success(self.handler.get_displayname(self.frank))
7163
7264 self.assertEquals("Frank", displayname)
7365
74 @defer.inlineCallbacks
7566 def test_set_my_name(self):
76 yield defer.ensureDeferred(
67 self.get_success(
7768 self.handler.set_displayname(
7869 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
7970 )
8172
8273 self.assertEquals(
8374 (
84 yield defer.ensureDeferred(
75 self.get_success(
8576 self.store.get_profile_displayname(self.frank.localpart)
8677 )
8778 ),
8980 )
9081
9182 # Set displayname again
92 yield defer.ensureDeferred(
83 self.get_success(
9384 self.handler.set_displayname(
9485 self.frank, synapse.types.create_requester(self.frank), "Frank"
9586 )
9788
9889 self.assertEquals(
9990 (
100 yield defer.ensureDeferred(
91 self.get_success(
10192 self.store.get_profile_displayname(self.frank.localpart)
10293 )
10394 ),
10596 )
10697
10798 # Set displayname to an empty string
108 yield defer.ensureDeferred(
99 self.get_success(
109100 self.handler.set_displayname(
110101 self.frank, synapse.types.create_requester(self.frank), ""
111102 )
112103 )
113104
114105 self.assertIsNone(
106 (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
107 )
108
109 def test_set_my_name_if_disabled(self):
110 self.hs.config.enable_set_displayname = False
111
112 # Setting displayname for the first time is allowed
113 self.get_success(
114 self.store.set_profile_displayname(self.frank.localpart, "Frank")
115 )
116
117 self.assertEquals(
115118 (
116 yield defer.ensureDeferred(
119 self.get_success(
117120 self.store.get_profile_displayname(self.frank.localpart)
118121 )
119 )
120 )
121
122 @defer.inlineCallbacks
123 def test_set_my_name_if_disabled(self):
124 self.hs.config.enable_set_displayname = False
125
126 # Setting displayname for the first time is allowed
127 yield defer.ensureDeferred(
128 self.store.set_profile_displayname(self.frank.localpart, "Frank")
129 )
130
131 self.assertEquals(
132 (
133 yield defer.ensureDeferred(
134 self.store.get_profile_displayname(self.frank.localpart)
135 )
136122 ),
137123 "Frank",
138124 )
139125
140126 # Setting displayname a second time is forbidden
141 d = defer.ensureDeferred(
127 self.get_failure(
142128 self.handler.set_displayname(
143129 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
144 )
145 )
146
147 yield self.assertFailure(d, SynapseError)
148
149 @defer.inlineCallbacks
130 ),
131 SynapseError,
132 )
133
150134 def test_set_my_name_noauth(self):
151 d = defer.ensureDeferred(
135 self.get_failure(
152136 self.handler.set_displayname(
153137 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
154 )
155 )
156
157 yield self.assertFailure(d, AuthError)
158
159 @defer.inlineCallbacks
138 ),
139 AuthError,
140 )
141
160142 def test_get_other_name(self):
161143 self.mock_federation.make_query.return_value = make_awaitable(
162144 {"displayname": "Alice"}
163145 )
164146
165 displayname = yield defer.ensureDeferred(
166 self.handler.get_displayname(self.alice)
167 )
147 displayname = self.get_success(self.handler.get_displayname(self.alice))
168148
169149 self.assertEquals(displayname, "Alice")
170150 self.mock_federation.make_query.assert_called_with(
174154 ignore_backoff=True,
175155 )
176156
177 @defer.inlineCallbacks
178157 def test_incoming_fed_query(self):
179 yield defer.ensureDeferred(self.store.create_profile("caroline"))
180 yield defer.ensureDeferred(
181 self.store.set_profile_displayname("caroline", "Caroline")
182 )
183
184 response = yield defer.ensureDeferred(
158 self.get_success(self.store.create_profile("caroline"))
159 self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
160
161 response = self.get_success(
185162 self.query_handlers["profile"](
186163 {"user_id": "@caroline:test", "field": "displayname"}
187164 )
189166
190167 self.assertEquals({"displayname": "Caroline"}, response)
191168
192 @defer.inlineCallbacks
193169 def test_get_my_avatar(self):
194 yield defer.ensureDeferred(
170 self.get_success(
195171 self.store.set_profile_avatar_url(
196172 self.frank.localpart, "http://my.server/me.png"
197173 )
198174 )
199 avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
175 avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
200176
201177 self.assertEquals("http://my.server/me.png", avatar_url)
202178
203 @defer.inlineCallbacks
204179 def test_set_my_avatar(self):
205 yield defer.ensureDeferred(
180 self.get_success(
206181 self.handler.set_avatar_url(
207182 self.frank,
208183 synapse.types.create_requester(self.frank),
211186 )
212187
213188 self.assertEquals(
214 (
215 yield defer.ensureDeferred(
216 self.store.get_profile_avatar_url(self.frank.localpart)
217 )
218 ),
189 (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
219190 "http://my.server/pic.gif",
220191 )
221192
222193 # Set avatar again
223 yield defer.ensureDeferred(
194 self.get_success(
224195 self.handler.set_avatar_url(
225196 self.frank,
226197 synapse.types.create_requester(self.frank),
229200 )
230201
231202 self.assertEquals(
232 (
233 yield defer.ensureDeferred(
234 self.store.get_profile_avatar_url(self.frank.localpart)
235 )
236 ),
203 (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
237204 "http://my.server/me.png",
238205 )
239206
240207 # Set avatar to an empty string
241 yield defer.ensureDeferred(
242 self.handler.set_avatar_url(
243 self.frank, synapse.types.create_requester(self.frank), "",
208 self.get_success(
209 self.handler.set_avatar_url(
210 self.frank,
211 synapse.types.create_requester(self.frank),
212 "",
244213 )
245214 )
246215
247216 self.assertIsNone(
248 (
249 yield defer.ensureDeferred(
250 self.store.get_profile_avatar_url(self.frank.localpart)
251 )
252 ),
253 )
254
255 @defer.inlineCallbacks
217 (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
218 )
219
256220 def test_set_my_avatar_if_disabled(self):
257221 self.hs.config.enable_set_avatar_url = False
258222
259223 # Setting avatar for the first time is allowed
260 yield defer.ensureDeferred(
224 self.get_success(
261225 self.store.set_profile_avatar_url(
262226 self.frank.localpart, "http://my.server/me.png"
263227 )
264228 )
265229
266230 self.assertEquals(
267 (
268 yield defer.ensureDeferred(
269 self.store.get_profile_avatar_url(self.frank.localpart)
270 )
271 ),
231 (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
272232 "http://my.server/me.png",
273233 )
274234
275235 # Setting avatar a second time is forbidden
276 d = defer.ensureDeferred(
236 self.get_failure(
277237 self.handler.set_avatar_url(
278238 self.frank,
279239 synapse.types.create_requester(self.frank),
280240 "http://my.server/pic.gif",
281 )
282 )
283
284 yield self.assertFailure(d, SynapseError)
241 ),
242 SynapseError,
243 )
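The hunk above converts `ProfileTestCase` from trial-style `@defer.inlineCallbacks` tests to `HomeserverTestCase`'s synchronous `get_success`/`get_failure` helpers. A rough, self-contained sketch of that helper pattern — using plain asyncio rather than Twisted, with a hypothetical `fetch_displayname` store call standing in for the real datastore — might look like:

```python
import asyncio


# Hypothetical async store call standing in for store.get_profile_displayname.
async def fetch_displayname(user):
    profiles = {"@frank:test": "Frank"}
    if user not in profiles:
        raise KeyError(user)
    return profiles[user]


def get_success(coro):
    # Drive the coroutine to completion and hand back its result: the
    # synchronous style that replaces `yield defer.ensureDeferred(...)`.
    return asyncio.run(coro)


def get_failure(coro, exc_type):
    # Run the coroutine and assert it raises exc_type, replacing the older
    # `d = defer.ensureDeferred(...); yield self.assertFailure(d, exc_type)`.
    try:
        asyncio.run(coro)
    except exc_type as e:
        return e
    raise AssertionError("expected %s to be raised" % exc_type.__name__)
```

In the real tests these helpers come from `tests.unittest.HomeserverTestCase` and run on Twisted's test reactor; the asyncio version above only illustrates the control flow the diff is migrating to.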
258258 )
259259 self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
260260
261 @override_config(
262 {
263 "saml2_config": {
264 "attribute_requirements": [
265 {"attribute": "userGroup", "value": "staff"},
266 {"attribute": "department", "value": "sales"},
267 ],
268 },
269 }
270 )
271 def test_attribute_requirements(self):
272 """The required attributes must be met from the SAML response."""
273
274 # stub out the auth handler
275 auth_handler = self.hs.get_auth_handler()
276 auth_handler.complete_sso_login = simple_async_mock()
277
278 # The response doesn't have the proper userGroup or department.
279 saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
280 request = _mock_request()
281 self.get_success(
282 self.handler._handle_authn_response(request, saml_response, "redirect_uri")
283 )
284 auth_handler.complete_sso_login.assert_not_called()
285
286 # The response doesn't have the proper department.
287 saml_response = FakeAuthnResponse(
288 {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
289 )
290 request = _mock_request()
291 self.get_success(
292 self.handler._handle_authn_response(request, saml_response, "redirect_uri")
293 )
294 auth_handler.complete_sso_login.assert_not_called()
295
296 # Add the proper attributes and it should succeed.
297 saml_response = FakeAuthnResponse(
298 {
299 "uid": "test_user",
300 "username": "test_user",
301 "userGroup": ["staff", "admin"],
302 "department": ["sales"],
303 }
304 )
305 request.reset_mock()
306 self.get_success(
307 self.handler._handle_authn_response(request, saml_response, "redirect_uri")
308 )
309
310 # check that the auth handler got called as expected
311 auth_handler.complete_sso_login.assert_called_once_with(
312 "@test_user:test", request, "redirect_uri", None, new_user=True
313 )
314
261315
262316 def _mock_request():
263317 """Returns a mock which will stand in as a SynapseRequest"""
264 return Mock(spec=["getClientIP", "getHeader"])
318 return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
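`test_attribute_requirements` above drives the SAML handler with responses that progressively satisfy the configured `attribute_requirements`. The gate it exercises can be sketched as follows — the function name and the list-valued attribute convention are inferred from the test data, not Synapse's actual implementation:

```python
def check_attribute_requirements(attributes, requirements):
    """Return True iff every required (attribute, value) pair is met.

    Sketch of the check exercised by test_attribute_requirements; the real
    logic lives in Synapse's SAML handler and may differ. SAML attribute
    values arrive as lists, so membership (not equality) is tested.
    """
    for requirement in requirements:
        values = attributes.get(requirement["attribute"], [])
        if requirement["value"] not in values:
            return False
    return True
```

With the config from the test, a response carrying only `userGroup: ["staff"]` fails (no `department`), while one carrying both `userGroup: ["staff", "admin"]` and `department: ["sales"]` passes.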
142142 self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
143143
144144 self.datastore.get_to_device_stream_token = lambda: 0
145 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
146 ([], 0)
147 )
148 self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
149 None
150 )
151 self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
152 None
145 self.datastore.get_new_device_msgs_for_remote = (
146 lambda *args, **kargs: make_awaitable(([], 0))
147 )
148 self.datastore.delete_device_msgs_for_remote = (
149 lambda *args, **kargs: make_awaitable(None)
150 )
151 self.datastore.set_received_txn_response = (
152 lambda *args, **kwargs: make_awaitable(None)
153153 )
154154
155155 def test_started_typing_local(self):
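The mocks above wrap canned results in `make_awaitable` so that a plain `lambda *args, **kargs: make_awaitable(...)` can substitute for an async datastore method. A minimal asyncio-based sketch of such a helper — the real one in `tests.test_utils` targets Twisted and may differ:

```python
import asyncio


def make_awaitable(result):
    # Return an awaitable that immediately resolves to `result`, so a plain
    # Mock or lambda can stand in for an async method.
    async def _resolved():
        return result

    return _resolved()


# A stubbed datastore method built the same way as in the tests above:
get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(([], 0))
```

Callers can then `await` the stub exactly as they would the real method, regardless of the arguments passed.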
199199
200200 # Check that the room has an encryption state event
201201 event_content = self.helper.get_state(
202 room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
202 room_id=room_id,
203 event_type=EventTypes.RoomEncryption,
204 tok=user_token,
203205 )
204206 self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
205207
208210
209211 # Check that the room has an encryption state event
210212 event_content = self.helper.get_state(
211 room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
213 room_id=room_id,
214 event_type=EventTypes.RoomEncryption,
215 tok=user_token,
212216 )
213217 self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
214218
226230
227231 # Check that the room has an encryption state event
228232 event_content = self.helper.get_state(
229 room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
233 room_id=room_id,
234 event_type=EventTypes.RoomEncryption,
235 tok=user_token,
230236 )
231237 self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
232238
517517 self.successResultOf(test_d)
518518
519519 def test_get_well_known(self):
520 """Test the behaviour when the .well-known delegates elsewhere
521 """
520 """Test the behaviour when the .well-known delegates elsewhere"""
522521
523522 self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
524523 self.reactor.lookups["testserv"] = "1.2.3.4"
11341133 self.assertIsNone(r.delegated_server)
11351134
11361135 def test_srv_fallbacks(self):
1137 """Test that other SRV results are tried if the first one fails.
1138 """
1136 """Test that other SRV results are tried if the first one fails."""
11391137 self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
11401138 [
11411139 Server(host=b"target.com", port=8443),
1717
1818 from twisted.python.failure import Failure
1919 from twisted.web.client import ResponseDone
20 from twisted.web.iweb import UNKNOWN_LENGTH
2021
2122 from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
2223
2627 class ReadBodyWithMaxSizeTests(TestCase):
2728 def setUp(self):
2829 """Start reading the body, returns the response, result and proto"""
29 self.response = Mock()
30 response = Mock(length=UNKNOWN_LENGTH)
3031 self.result = BytesIO()
31 self.deferred = read_body_with_max_size(self.response, self.result, 6)
32 self.deferred = read_body_with_max_size(response, self.result, 6)
3233
3334 # Fish the protocol out of the response.
34 self.protocol = self.response.deliverBody.call_args[0][0]
35 self.protocol = response.deliverBody.call_args[0][0]
3536 self.protocol.transport = Mock()
3637
3738 def _cleanup_error(self):
8788 self.protocol.dataReceived(b"1234567890")
8889 self.assertIsInstance(self.deferred.result, Failure)
8990 self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
90 self.protocol.transport.loseConnection.assert_called_once()
91 self.protocol.transport.abortConnection.assert_called_once()
9192
9293 # More data might have come in.
9394 self.protocol.dataReceived(b"1234567890")
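The assertion change above checks that exceeding the size limit triggers `abortConnection()` (a hard reset) rather than `loseConnection()` (a graceful close, which would still flush buffered data and let the peer keep streaming). A toy collector showing why data arriving after the breach must be dropped — names are illustrative, not Synapse's `read_body_with_max_size`:

```python
class BodyExceededMaxSize(Exception):
    """Signals that a response body grew past the configured limit."""


class FakeTransport:
    """Records abort calls, standing in for a Twisted transport."""

    def __init__(self):
        self.aborted = 0

    def abort_connection(self):
        self.aborted += 1


class BoundedBodyCollector:
    """Buffers body data until `max_size` is breached, then aborts."""

    def __init__(self, transport, max_size):
        self.transport = transport
        self.max_size = max_size
        self.buffer = b""
        self.failure = None

    def data_received(self, data):
        if self.failure is not None:
            # The transport was already aborted; late-arriving data is dropped.
            return
        self.buffer += data
        if len(self.buffer) > self.max_size:
            self.failure = BodyExceededMaxSize()
            # Hard abort (RST) rather than a graceful close: a graceful
            # close would keep flushing and defeat the size limit.
            self.transport.abort_connection()
```

Feeding ten bytes into a six-byte collector aborts once; a second delivery is ignored rather than triggering a second abort, mirroring the "more data might have come in" step in the test.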
9999
100100 # Check that the event was sent
101101 self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
102 expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
102 expected_requester,
103 event_dict,
104 ratelimit=False,
105 ignore_shadow_ban=True,
103106 )
104107
105108 # Create and send a state event
123123 )
124124 self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
125125
126 # The other user sends some messages
126 # The other user sends a single message.
127 self.helper.send(room, body="Hi!", tok=self.others[0].token)
128
129 # We should get emailed about that message
130 self._check_for_mail()
131
132 # The other user sends multiple messages.
127133 self.helper.send(room, body="Hi!", tok=self.others[0].token)
128134 self.helper.send(room, body="There!", tok=self.others[0].token)
129135
130 # We should get emailed about that message
131136 self._check_for_mail()
132137
133138 def test_invite_sends_email(self):
216221 # We should get emailed about those messages
217222 self._check_for_mail()
218223
224 def test_empty_room(self):
225 """All users leaving a room shouldn't cause the pusher to break."""
226 # Create a simple room with two users
227 room = self.helper.create_room_as(self.user_id, tok=self.access_token)
228 self.helper.invite(
229 room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
230 )
231 self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
232
233 # The other user sends a single message.
234 self.helper.send(room, body="Hi!", tok=self.others[0].token)
235
236 # Leave the room before the message is processed.
237 self.helper.leave(room, self.user_id, tok=self.access_token)
238 self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
239
240 # We should get emailed about that message
241 self._check_for_mail()
242
243 def test_empty_room_multiple_messages(self):
244 """All users leaving a room shouldn't cause the pusher to break."""
245 # Create a simple room with two users
246 room = self.helper.create_room_as(self.user_id, tok=self.access_token)
247 self.helper.invite(
248 room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
249 )
250 self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
251
252 # The other user sends multiple messages.
253 self.helper.send(room, body="Hi!", tok=self.others[0].token)
254 self.helper.send(room, body="There!", tok=self.others[0].token)
255
256 # Leave the room before the message is processed.
257 self.helper.leave(room, self.user_id, tok=self.access_token)
258 self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
259
260 # We should get emailed about that message
261 self._check_for_mail()
262
219263 def test_encrypted_message(self):
220264 room = self.helper.create_room_as(self.user_id, tok=self.access_token)
221265 self.helper.invite(
268312 pushers = list(pushers)
269313 self.assertEqual(len(pushers), 1)
270314 self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
315
316 # Reset the attempts.
317 self.email_attempts = []
7878
7979 repl_handler = ReplicationCommandHandler(self.worker_hs)
8080 self.client = ClientReplicationStreamProtocol(
81 self.worker_hs, "client", "test", clock, repl_handler,
81 self.worker_hs,
82 "client",
83 "test",
84 clock,
85 repl_handler,
8286 )
8387
8488 self._client_transport = None
227231 if self.hs.config.redis.redis_enabled:
228232 # Handle attempts to connect to fake redis server.
229233 self.reactor.add_tcp_client_callback(
230 "localhost", 6379, self.connect_any_redis_attempts,
234 "localhost",
235 6379,
236 self.connect_any_redis_attempts,
231237 )
232238
233239 self.hs.get_tcp_replication().start_replication(self.hs)
245251 )
246252
247253 def create_test_resource(self):
248 """Overrides `HomeserverTestCase.create_test_resource`.
249 """
254 """Overrides `HomeserverTestCase.create_test_resource`."""
250255 # We override this so that it automatically registers all the HTTP
251256 # replication servlets, without having to explicitly do that in all
252257 # subclasses.
295300 if instance_loc.host not in self.reactor.lookups:
296301 raise Exception(
297302 "Host does not have an IP for instance_map[%r].host = %r"
298 % (instance_name, instance_loc.host,)
303 % (
304 instance_name,
305 instance_loc.host,
306 )
299307 )
300308
301309 self.reactor.add_tcp_client_callback(
314322 if not worker_hs.config.redis_enabled:
315323 repl_handler = ReplicationCommandHandler(worker_hs)
316324 client = ClientReplicationStreamProtocol(
317 worker_hs, "client", "test", self.clock, repl_handler,
325 worker_hs,
326 "client",
327 "test",
328 self.clock,
329 repl_handler,
318330 )
319331 server = self.server_factory.buildProtocol(None)
320332
484496 self._pull_to_push_producer.stop()
485497
486498 def checkPersistence(self, request, version):
487 """Check whether the connection can be re-used
488 """
499 """Check whether the connection can be re-used"""
489500 # We hijack this to always say no for ease of wiring stuff up in
490501 # `handle_http_replication_attempt`.
491502 request.responseHeaders.setRawHeaders(b"connection", [b"close"])
493504
494505
495506 class _PullToPushProducer:
496 """A push producer that wraps a pull producer.
497 """
507 """A push producer that wraps a pull producer."""
498508
499509 def __init__(
500510 self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
511521 self._start_loop()
512522
513523 def _start_loop(self):
514 """Start the looping call to
515 """
524 """Start the looping call to"""
516525
517526 if not self._looping_call:
518527 # Start a looping call which runs every tick.
519528 self._looping_call = self._clock.looping_call(self._run_once, 0)
520529
521530 def stop(self):
522 """Stops calling resumeProducing.
523 """
531 """Stops calling resumeProducing."""
524532 if self._looping_call:
525533 self._looping_call.stop()
526534 self._looping_call = None
527535
528536 def pauseProducing(self):
529 """Implements IPushProducer
530 """
537 """Implements IPushProducer"""
531538 self.stop()
532539
533540 def resumeProducing(self):
534 """Implements IPushProducer
535 """
541 """Implements IPushProducer"""
536542 self._start_loop()
537543
538544 def stopProducing(self):
539 """Implements IPushProducer
540 """
545 """Implements IPushProducer"""
541546 self.stop()
542547 self._producer.stopProducing()
543548
544549 def _run_once(self):
545 """Calls resumeProducing on producer once.
546 """
550 """Calls resumeProducing on producer once."""
547551
548552 try:
549553 self._producer.resumeProducing()
558562
559563
560564 class FakeRedisPubSubServer:
561 """A fake Redis server for pub/sub.
562 """
565 """A fake Redis server for pub/sub."""
563566
564567 def __init__(self):
565568 self._subscribers = set()
566569
567570 def add_subscriber(self, conn):
568 """A connection has called SUBSCRIBE
569 """
571 """A connection has called SUBSCRIBE"""
570572 self._subscribers.add(conn)
571573
572574 def remove_subscriber(self, conn):
573 """A connection has called UNSUBSCRIBE
574 """
575 """A connection has called UNSUBSCRIBE"""
575576 self._subscribers.discard(conn)
576577
577578 def publish(self, conn, channel, msg) -> int:
578 """A connection want to publish a message to subscribers.
579 """
579 """A connection want to publish a message to subscribers."""
580580 for sub in self._subscribers:
581581 sub.send(["message", channel, msg])
582582
587587
588588
589589 class FakeRedisPubSubProtocol(Protocol):
590 """A connection from a client talking to the fake Redis server.
591 """
590 """A connection from a client talking to the fake Redis server."""
592591
593592 def __init__(self, server: FakeRedisPubSubServer):
594593 self._server = server
612611 self.handle_command(msg[0], *msg[1:])
613612
614613 def handle_command(self, command, *args):
615 """Received a Redis command from the client.
616 """
614 """Received a Redis command from the client."""
617615
618616 # We currently only support pub/sub.
619617 if command == b"PUBLISH":
634632 raise Exception("Unknown command")
635633
636634 def send(self, msg):
637 """Send a message back to the client.
638 """
635 """Send a message back to the client."""
639636 raw = self.encode(msg).encode("utf-8")
640637
641638 self.transport.write(raw)
6565
6666 self.get_success(
6767 self.master_store.store_room(
68 ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
68 ROOM_ID,
69 USER_ID,
70 is_public=False,
71 room_version=RoomVersions.V1,
6972 )
7073 )
7174
2222
2323 class AccountDataStreamTestCase(BaseStreamTestCase):
2424 def test_update_function_room_account_data_limit(self):
25 """Test replication with many room account data updates
26 """
25 """Test replication with many room account data updates"""
2726 store = self.hs.get_datastore()
2827
2928 # generate lots of account data updates
6968 self.assertEqual([], received_rows)
7069
7170 def test_update_function_global_account_data_limit(self):
72 """Test replication with many global account data updates
73 """
71 """Test replication with many global account data updates"""
7472 store = self.hs.get_datastore()
7573
7674 # generate lots of account data updates
128128 )
129129 pls["users"][OTHER_USER] = 50
130130 self.helper.send_state(
131 self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
131 self.room_id,
132 EventTypes.PowerLevels,
133 pls,
134 tok=self.user_tok,
132135 )
133136
134137 # this is the point in the DAG where we make a fork
254257 self.assertIsNone(sr.event_id)
255258
256259 def test_update_function_state_row_limit(self):
257 """Test replication with many state events over several stream ids.
258 """
260 """Test replication with many state events over several stream ids."""
259261
260262 # we want to generate lots of state changes, but for this test, we want to
261263 # spread out the state changes over a few stream IDs.
281283 )
282284 pls["users"].update({u: 50 for u in user_ids})
283285 self.helper.send_state(
284 self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
286 self.room_id,
287 EventTypes.PowerLevels,
288 pls,
289 tok=self.user_tok,
285290 )
286291
287292 # this is the point in the DAG where we make a fork
2727 self.factory = ReplicationStreamProtocolFactory(hs)
2828
2929 def _make_client(self) -> Tuple[IProtocol, StringTransport]:
30 """Create a new direct TCP replication connection
31 """
30 """Create a new direct TCP replication connection"""
3231
3332 proto = self.factory.buildProtocol(("127.0.0.1", 0))
3433 transport = StringTransport()
7878 )
7979
8080 def test_no_auth(self):
81 """With no authentication the request should finish.
82 """
81 """With no authentication the request should finish."""
8382 channel = self._test_register()
8483 self.assertEqual(channel.code, 200)
8584
8887
8988 @override_config({"main_replication_secret": "my-secret"})
9089 def test_missing_auth(self):
91 """If the main process expects a secret that is not provided, an error results.
92 """
90 """If the main process expects a secret that is not provided, an error results."""
9391 channel = self._test_register()
9492 self.assertEqual(channel.code, 500)
9593
10098 }
10199 )
102100 def test_unauthorized(self):
103 """If the main process receives the wrong secret, an error results.
104 """
101 """If the main process receives the wrong secret, an error results."""
105102 channel = self._test_register()
106103 self.assertEqual(channel.code, 500)
107104
108105 @override_config({"worker_replication_secret": "my-secret"})
109106 def test_authorized(self):
110 """The request should finish when the worker provides the authentication header.
111 """
107 """The request should finish when the worker provides the authentication header."""
112108 channel = self._test_register()
113109 self.assertEqual(channel.code, 200)
114110
3434 return config
3535
3636 def test_register_single_worker(self):
37 """Test that registration works when using a single client reader worker.
38 """
37 """Test that registration works when using a single client reader worker."""
3938 worker_hs = self.make_worker_hs("synapse.app.client_reader")
4039 site = self._hs_to_site[worker_hs]
4140
6564 self.assertEqual(channel_2.json_body["user_id"], "@user:test")
6665
6766 def test_register_multi_worker(self):
68 """Test that registration works when using multiple client reader workers.
69 """
67 """Test that registration works when using multiple client reader workers."""
7068 worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
7169 worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
7270
3535
3636
3737 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
38 """Checks running multiple media repos work correctly.
39 """
38 """Checks running multiple media repos work correctly."""
4039
4140 servlets = [
4241 admin.register_servlets_for_client_rest_resource,
123122 return channel, request
124123
125124 def test_basic(self):
126 """Test basic fetching of remote media from a single worker.
127 """
125 """Test basic fetching of remote media from a single worker."""
128126 hs1 = self.make_worker_hs("synapse.app.generic_worker")
129127
130128 channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
222220 self.assertEqual(start_count + 3, self._count_remote_thumbnails())
223221
224222 def _count_remote_media(self) -> int:
225 """Count the number of files in our remote media directory.
226 """
223 """Count the number of files in our remote media directory."""
227224 path = os.path.join(
228225 self.hs.get_media_repository().primary_base_path, "remote_content"
229226 )
230227 return sum(len(files) for _, _, files in os.walk(path))
231228
232229 def _count_remote_thumbnails(self) -> int:
233 """Count the number of files in our remote thumbnails directory.
234 """
230 """Count the number of files in our remote thumbnails directory."""
235231 path = os.path.join(
236232 self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
237233 )
2626
2727
2828 class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
29 """Checks pusher sharding works
30 """
29 """Checks pusher sharding works"""
3130
3231 servlets = [
3332 admin.register_servlets_for_client_rest_resource,
8786 return event_id
8887
8988 def test_send_push_single_worker(self):
90 """Test that registration works when using a pusher worker.
91 """
89 """Test that registration works when using a pusher worker."""
9290 http_client_mock = Mock(spec_set=["post_json_get_json"])
93 http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
94 {}
91 http_client_mock.post_json_get_json.side_effect = (
92 lambda *_, **__: defer.succeed({})
9593 )
9694
9795 self.make_worker_hs(
118116 )
119117
120118 def test_send_push_multiple_workers(self):
121 """Test that registration works when using sharded pusher workers.
122 """
119 """Test that registration works when using sharded pusher workers."""
123120 http_client_mock1 = Mock(spec_set=["post_json_get_json"])
124 http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
125 {}
121 http_client_mock1.post_json_get_json.side_effect = (
122 lambda *_, **__: defer.succeed({})
126123 )
127124
128125 self.make_worker_hs(
136133 )
137134
138135 http_client_mock2 = Mock(spec_set=["post_json_get_json"])
139 http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
140 {}
136 http_client_mock2.post_json_get_json.side_effect = (
137 lambda *_, **__: defer.succeed({})
141138 )
142139
143140 self.make_worker_hs(
2828
2929
3030 class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
31 """Checks event persisting sharding works
32 """
31 """Checks event persisting sharding works"""
3332
3433 # Event persister sharding requires postgres (due to needing
3534 # `MultiWriterIdGenerator`).
6261 return conf
6362
6463 def _create_room(self, room_id: str, user_id: str, tok: str):
65 """Create a room with given room_id
66 """
64 """Create a room with given room_id"""
6765
6866 # We control the room ID generation by patching out the
6967 # `_generate_room_id` method
9088 """
9189
9290 self.make_worker_hs(
93 "synapse.app.generic_worker", {"worker_name": "worker1"},
91 "synapse.app.generic_worker",
92 {"worker_name": "worker1"},
9493 )
9594
9695 self.make_worker_hs(
97 "synapse.app.generic_worker", {"worker_name": "worker2"},
96 "synapse.app.generic_worker",
97 {"worker_name": "worker2"},
9898 )
9999
100100 persisted_on_1 = False
138138 """
139139
140140 self.make_worker_hs(
141 "synapse.app.generic_worker", {"worker_name": "worker1"},
141 "synapse.app.generic_worker",
142 {"worker_name": "worker1"},
142143 )
143144
144145 worker_hs2 = self.make_worker_hs(
145 "synapse.app.generic_worker", {"worker_name": "worker2"},
146 "synapse.app.generic_worker",
147 {"worker_name": "worker2"},
146148 )
147149
148150 sync_hs = self.make_worker_hs(
149 "synapse.app.generic_worker", {"worker_name": "sync"},
151 "synapse.app.generic_worker",
152 {"worker_name": "sync"},
150153 )
151154 sync_hs_site = self._hs_to_site[sync_hs]
152155
322325 sync_hs_site,
323326 "GET",
324327 "/rooms/{}/messages?from={}&to={}&dir=f".format(
325 room_id2, vector_clock_token, prev_batch2,
328 room_id2,
329 vector_clock_token,
330 prev_batch2,
326331 ),
327332 access_token=access_token,
328333 )
129129 )
130130
131131 def _get_groups_user_is_in(self, access_token):
132 """Returns the list of groups the user is in (given their access token)
133 """
132 """Returns the list of groups the user is in (given their access token)"""
134133 channel = self.make_request(
135134 "GET", "/joined_groups".encode("ascii"), access_token=access_token
136135 )
141140
142141
143142 class QuarantineMediaTestCase(unittest.HomeserverTestCase):
144 """Test /quarantine_media admin API.
145 """
143 """Test /quarantine_media admin API."""
146144
147145 servlets = [
148146 synapse.rest.admin.register_servlets,
236234 # Attempt quarantine media APIs as non-admin
237235 url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
238236 channel = self.make_request(
239 "POST", url.encode("ascii"), access_token=non_admin_user_tok,
237 "POST",
238 url.encode("ascii"),
239 access_token=non_admin_user_tok,
240240 )
241241
242242 # Expect a forbidden error
249249 # And the roomID/userID endpoint
250250 url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
251251 channel = self.make_request(
252 "POST", url.encode("ascii"), access_token=non_admin_user_tok,
252 "POST",
253 url.encode("ascii"),
254 access_token=non_admin_user_tok,
253255 )
254256
255257 # Expect a forbidden error
293295 urllib.parse.quote(server_name),
294296 urllib.parse.quote(media_id),
295297 )
296 channel = self.make_request("POST", url, access_token=admin_user_tok,)
298 channel = self.make_request(
299 "POST",
300 url,
301 access_token=admin_user_tok,
302 )
297303 self.pump(1.0)
298304 self.assertEqual(200, int(channel.code), msg=channel.result["body"])
299305
345351 url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
346352 room_id
347353 )
348 channel = self.make_request("POST", url, access_token=admin_user_tok,)
354 channel = self.make_request(
355 "POST",
356 url,
357 access_token=admin_user_tok,
358 )
349359 self.pump(1.0)
350360 self.assertEqual(200, int(channel.code), msg=channel.result["body"])
351361 self.assertEqual(
390400 non_admin_user
391401 )
392402 channel = self.make_request(
393 "POST", url.encode("ascii"), access_token=admin_user_tok,
403 "POST",
404 url.encode("ascii"),
405 access_token=admin_user_tok,
394406 )
395407 self.pump(1.0)
396408 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
436448 non_admin_user
437449 )
438450 channel = self.make_request(
439 "POST", url.encode("ascii"), access_token=admin_user_tok,
451 "POST",
452 url.encode("ascii"),
453 access_token=admin_user_tok,
440454 )
441455 self.pump(1.0)
442456 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
6969 If the user is not a server admin, an error is returned.
7070 """
7171 channel = self.make_request(
72 "GET", self.url, access_token=self.other_user_token,
72 "GET",
73 self.url,
74 access_token=self.other_user_token,
7375 )
7476
7577 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
7678 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
7779
7880 channel = self.make_request(
79 "PUT", self.url, access_token=self.other_user_token,
81 "PUT",
82 self.url,
83 access_token=self.other_user_token,
8084 )
8185
8286 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
8387 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
8488
8589 channel = self.make_request(
86 "DELETE", self.url, access_token=self.other_user_token,
90 "DELETE",
91 self.url,
92 access_token=self.other_user_token,
8793 )
8894
8995 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
98104 % self.other_user_device_id
99105 )
100106
101 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
107 channel = self.make_request(
108 "GET",
109 url,
110 access_token=self.admin_user_tok,
111 )
102112
103113 self.assertEqual(404, channel.code, msg=channel.json_body)
104114 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
105115
106 channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
116 channel = self.make_request(
117 "PUT",
118 url,
119 access_token=self.admin_user_tok,
120 )
107121
108122 self.assertEqual(404, channel.code, msg=channel.json_body)
109123 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
110124
111 channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
125 channel = self.make_request(
126 "DELETE",
127 url,
128 access_token=self.admin_user_tok,
129 )
112130
113131 self.assertEqual(404, channel.code, msg=channel.json_body)
114132 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
122140 % self.other_user_device_id
123141 )
124142
125 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
143 channel = self.make_request(
144 "GET",
145 url,
146 access_token=self.admin_user_tok,
147 )
126148
127149 self.assertEqual(400, channel.code, msg=channel.json_body)
128150 self.assertEqual("Can only lookup local users", channel.json_body["error"])
129151
130 channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
152 channel = self.make_request(
153 "PUT",
154 url,
155 access_token=self.admin_user_tok,
156 )
131157
132158 self.assertEqual(400, channel.code, msg=channel.json_body)
133159 self.assertEqual("Can only lookup local users", channel.json_body["error"])
134160
135 channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
161 channel = self.make_request(
162 "DELETE",
163 url,
164 access_token=self.admin_user_tok,
165 )
136166
137167 self.assertEqual(400, channel.code, msg=channel.json_body)
138168 self.assertEqual("Can only lookup local users", channel.json_body["error"])
145175 self.other_user
146176 )
147177
148 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
178 channel = self.make_request(
179 "GET",
180 url,
181 access_token=self.admin_user_tok,
182 )
149183
150184 self.assertEqual(404, channel.code, msg=channel.json_body)
151185 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
152186
153 channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
154
155 self.assertEqual(200, channel.code, msg=channel.json_body)
156
157 channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
187 channel = self.make_request(
188 "PUT",
189 url,
190 access_token=self.admin_user_tok,
191 )
192
193 self.assertEqual(200, channel.code, msg=channel.json_body)
194
195 channel = self.make_request(
196 "DELETE",
197 url,
198 access_token=self.admin_user_tok,
199 )
158200
159201 # Deleting an unknown device returns status 200
160202 self.assertEqual(200, channel.code, msg=channel.json_body)
189231 self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
190232
191233 # Ensure the display name was not updated.
192 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
234 channel = self.make_request(
235 "GET",
236 self.url,
237 access_token=self.admin_user_tok,
238 )
193239
194240 self.assertEqual(200, channel.code, msg=channel.json_body)
195241 self.assertEqual("new display", channel.json_body["display_name"])
206252 )
207253 )
208254
209 channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,)
255 channel = self.make_request(
256 "PUT",
257 self.url,
258 access_token=self.admin_user_tok,
259 )
210260
211261 self.assertEqual(200, channel.code, msg=channel.json_body)
212262
213263 # Ensure the display name was not updated.
214 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
264 channel = self.make_request(
265 "GET",
266 self.url,
267 access_token=self.admin_user_tok,
268 )
215269
216270 self.assertEqual(200, channel.code, msg=channel.json_body)
217271 self.assertEqual("new display", channel.json_body["display_name"])
232286 self.assertEqual(200, channel.code, msg=channel.json_body)
233287
234288 # Check new display_name
235 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
289 channel = self.make_request(
290 "GET",
291 self.url,
292 access_token=self.admin_user_tok,
293 )
236294
237295 self.assertEqual(200, channel.code, msg=channel.json_body)
238296 self.assertEqual("new displayname", channel.json_body["display_name"])
241299 """
242300 Tests that a normal lookup for a device is successful
243301 """
244 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
302 channel = self.make_request(
303 "GET",
304 self.url,
305 access_token=self.admin_user_tok,
306 )
245307
246308 self.assertEqual(200, channel.code, msg=channel.json_body)
247309 self.assertEqual(self.other_user, channel.json_body["user_id"])
263325
264326 # Delete device
265327 channel = self.make_request(
266 "DELETE", self.url, access_token=self.admin_user_tok,
328 "DELETE",
329 self.url,
330 access_token=self.admin_user_tok,
267331 )
268332
269333 self.assertEqual(200, channel.code, msg=channel.json_body)
305369 """
306370 other_user_token = self.login("user", "pass")
307371
308 channel = self.make_request("GET", self.url, access_token=other_user_token,)
372 channel = self.make_request(
373 "GET",
374 self.url,
375 access_token=other_user_token,
376 )
309377
310378 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
311379 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
315383 Tests that a lookup for a user that does not exist returns a 404
316384 """
317385 url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
318 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
386 channel = self.make_request(
387 "GET",
388 url,
389 access_token=self.admin_user_tok,
390 )
319391
320392 self.assertEqual(404, channel.code, msg=channel.json_body)
321393 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
326398 """
327399 url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
328400
329 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
401 channel = self.make_request(
402 "GET",
403 url,
404 access_token=self.admin_user_tok,
405 )
330406
331407 self.assertEqual(400, channel.code, msg=channel.json_body)
332408 self.assertEqual("Can only lookup local users", channel.json_body["error"])
338414 """
339415
340416 # Get devices
341 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
417 channel = self.make_request(
418 "GET",
419 self.url,
420 access_token=self.admin_user_tok,
421 )
342422
343423 self.assertEqual(200, channel.code, msg=channel.json_body)
344424 self.assertEqual(0, channel.json_body["total"])
354434 self.login("user", "pass")
355435
356436 # Get devices
357 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
437 channel = self.make_request(
438 "GET",
439 self.url,
440 access_token=self.admin_user_tok,
441 )
358442
359443 self.assertEqual(200, channel.code, msg=channel.json_body)
360444 self.assertEqual(number_devices, channel.json_body["total"])
403487 """
404488 other_user_token = self.login("user", "pass")
405489
406 channel = self.make_request("POST", self.url, access_token=other_user_token,)
490 channel = self.make_request(
491 "POST",
492 self.url,
493 access_token=other_user_token,
494 )
407495
408496 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
409497 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
413501 Tests that a lookup for a user that does not exist returns a 404
414502 """
415503 url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
416 channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
504 channel = self.make_request(
505 "POST",
506 url,
507 access_token=self.admin_user_tok,
508 )
417509
418510 self.assertEqual(404, channel.code, msg=channel.json_body)
419511 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
424516 """
425517 url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
426518
427 channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
519 channel = self.make_request(
520 "POST",
521 url,
522 access_token=self.admin_user_tok,
523 )
428524
429525 self.assertEqual(400, channel.code, msg=channel.json_body)
430526 self.assertEqual("Can only lookup local users", channel.json_body["error"])
5050 # Two rooms and two users. Every user sends and reports every room event
5151 for i in range(5):
5252 self._create_event_and_report(
53 room_id=self.room_id1, user_tok=self.other_user_tok,
53 room_id=self.room_id1,
54 user_tok=self.other_user_tok,
5455 )
5556 for i in range(5):
5657 self._create_event_and_report(
57 room_id=self.room_id2, user_tok=self.other_user_tok,
58 room_id=self.room_id2,
59 user_tok=self.other_user_tok,
5860 )
5961 for i in range(5):
6062 self._create_event_and_report(
61 room_id=self.room_id1, user_tok=self.admin_user_tok,
63 room_id=self.room_id1,
64 user_tok=self.admin_user_tok,
6265 )
6366 for i in range(5):
6467 self._create_event_and_report(
65 room_id=self.room_id2, user_tok=self.admin_user_tok,
68 room_id=self.room_id2,
69 user_tok=self.admin_user_tok,
6670 )
6771
6872 self.url = "/_synapse/admin/v1/event_reports"
8185 If the user is not a server admin, an error 403 is returned.
8286 """
8387
84 channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
88 channel = self.make_request(
89 "GET",
90 self.url,
91 access_token=self.other_user_tok,
92 )
8593
8694 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
8795 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
9199 Tests listing reported events
92100 """
93101
94 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
102 channel = self.make_request(
103 "GET",
104 self.url,
105 access_token=self.admin_user_tok,
106 )
95107
96108 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
97109 self.assertEqual(channel.json_body["total"], 20)
105117 """
106118
107119 channel = self.make_request(
108 "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
120 "GET",
121 self.url + "?limit=5",
122 access_token=self.admin_user_tok,
109123 )
110124
111125 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
120134 """
121135
122136 channel = self.make_request(
123 "GET", self.url + "?from=5", access_token=self.admin_user_tok,
137 "GET",
138 self.url + "?from=5",
139 access_token=self.admin_user_tok,
124140 )
125141
126142 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
135151 """
136152
137153 channel = self.make_request(
138 "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
154 "GET",
155 self.url + "?from=5&limit=10",
156 access_token=self.admin_user_tok,
139157 )
140158
141159 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
212230
213231 # fetch the most recent first, largest timestamp
214232 channel = self.make_request(
215 "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
233 "GET",
234 self.url + "?dir=b",
235 access_token=self.admin_user_tok,
216236 )
217237
218238 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
228248
229249 # fetch the oldest first, smallest timestamp
230250 channel = self.make_request(
231 "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
251 "GET",
252 self.url + "?dir=f",
253 access_token=self.admin_user_tok,
232254 )
233255
234256 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
248270 """
249271
250272 channel = self.make_request(
251 "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
273 "GET",
274 self.url + "?dir=bar",
275 access_token=self.admin_user_tok,
252276 )
253277
254278 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
261285 """
262286
263287 channel = self.make_request(
264 "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
288 "GET",
289 self.url + "?limit=-5",
290 access_token=self.admin_user_tok,
265291 )
266292
267293 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
273299 """
274300
275301 channel = self.make_request(
276 "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
302 "GET",
303 self.url + "?from=-5",
304 access_token=self.admin_user_tok,
277305 )
278306
279307 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
287315 # `next_token` does not appear
288316 # Number of results is the number of entries
289317 channel = self.make_request(
290 "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
318 "GET",
319 self.url + "?limit=20",
320 access_token=self.admin_user_tok,
291321 )
292322
293323 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
298328 # `next_token` does not appear
299329 # Number of max results is larger than the number of entries
300330 channel = self.make_request(
301 "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
331 "GET",
332 self.url + "?limit=21",
333 access_token=self.admin_user_tok,
302334 )
303335
304336 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
309341 # `next_token` does appear
310342 # Number of max results is smaller than the number of entries
311343 channel = self.make_request(
312 "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
344 "GET",
345 self.url + "?limit=19",
346 access_token=self.admin_user_tok,
313347 )
314348
315349 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
321355 # Set `from` to the value of `next_token` to request the remaining entries
322356 # `next_token` does not appear
323357 channel = self.make_request(
324 "GET", self.url + "?from=19", access_token=self.admin_user_tok,
358 "GET",
359 self.url + "?from=19",
360 access_token=self.admin_user_tok,
325361 )
326362
327363 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
330366 self.assertNotIn("next_token", channel.json_body)
331367
332368 def _create_event_and_report(self, room_id, user_tok):
333 """Create and report events
334 """
369 """Create and report events"""
335370 resp = self.helper.send(room_id, tok=user_tok)
336371 event_id = resp["event_id"]
337372
344379 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
345380
346381 def _check_fields(self, content):
347 """Checks that all attributes are present in an event report
348 """
382 """Checks that all attributes are present in an event report"""
349383 for c in content:
350384 self.assertIn("id", c)
351385 self.assertIn("received_ts", c)
380414 self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
381415
382416 self._create_event_and_report(
383 room_id=self.room_id1, user_tok=self.other_user_tok,
417 room_id=self.room_id1,
418 user_tok=self.other_user_tok,
384419 )
385420
386421 # first created event report gets `id`=2
400435 If the user is not a server admin, an error 403 is returned.
401436 """
402437
403 channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
438 channel = self.make_request(
439 "GET",
440 self.url,
441 access_token=self.other_user_tok,
442 )
404443
405444 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
406445 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
410449 Tests getting a reported event
411450 """
412451
413 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
452 channel = self.make_request(
453 "GET",
454 self.url,
455 access_token=self.admin_user_tok,
456 )
414457
415458 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
416459 self._check_fields(channel.json_body)
478521 self.assertEqual("Event report not found", channel.json_body["error"])
479522
480523 def _create_event_and_report(self, room_id, user_tok):
481 """Create and report events
482 """
524 """Create and report events"""
483525 resp = self.helper.send(room_id, tok=user_tok)
484526 event_id = resp["event_id"]
485527
492534 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
493535
494536 def _check_fields(self, content):
495 """Checks that all attributes are present in an event report
496 """
537 """Checks that all attributes are present in an event report"""
497538 self.assertIn("id", content)
498539 self.assertIn("received_ts", content)
499540 self.assertIn("room_id", content)
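The event-report listing tests above drive `/_synapse/admin/v1/event_reports` with `from`, `limit`, and `dir` query parameters. Building such a query string can be sketched with the standard library (the helper name and default values here are illustrative, not part of Synapse):

```python
from urllib.parse import urlencode

def event_reports_url(from_=0, limit=10, direction="b"):
    # `dir` accepts "b" (newest first) or "f" (oldest first), per the tests above
    params = urlencode({"from": from_, "limit": limit, "dir": direction})
    return "/_synapse/admin/v1/event_reports?" + params

print(event_reports_url(from_=5, limit=10))
# /_synapse/admin/v1/event_reports?from=5&limit=10&dir=b
```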
6262
6363 url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
6464
65 channel = self.make_request("DELETE", url, access_token=self.other_user_token,)
65 channel = self.make_request(
66 "DELETE",
67 url,
68 access_token=self.other_user_token,
69 )
6670
6771 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
6872 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
7377 """
7478 url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
7579
76 channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
80 channel = self.make_request(
81 "DELETE",
82 url,
83 access_token=self.admin_user_tok,
84 )
7785
7886 self.assertEqual(404, channel.code, msg=channel.json_body)
7987 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
8492 """
8593 url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
8694
87 channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
95 channel = self.make_request(
96 "DELETE",
97 url,
98 access_token=self.admin_user_tok,
99 )
88100
89101 self.assertEqual(400, channel.code, msg=channel.json_body)
90102 self.assertEqual("Can only delete local media", channel.json_body["error"])
138150 url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
139151
140152 # Delete media
141 channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
153 channel = self.make_request(
154 "DELETE",
155 url,
156 access_token=self.admin_user_tok,
157 )
142158
143159 self.assertEqual(200, channel.code, msg=channel.json_body)
144160 self.assertEqual(1, channel.json_body["total"])
145161 self.assertEqual(
146 media_id, channel.json_body["deleted_media"][0],
162 media_id,
163 channel.json_body["deleted_media"][0],
147164 )
148165
149166 # Attempt to access media
206223 self.other_user_token = self.login("user", "pass")
207224
208225 channel = self.make_request(
209 "POST", self.url, access_token=self.other_user_token,
226 "POST",
227 self.url,
228 access_token=self.other_user_token,
210229 )
211230
212231 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
219238 url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
220239
221240 channel = self.make_request(
222 "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
241 "POST",
242 url + "?before_ts=1234",
243 access_token=self.admin_user_tok,
223244 )
224245
225246 self.assertEqual(400, channel.code, msg=channel.json_body)
229250 """
230251 If the parameter `before_ts` is missing, an error is returned.
231252 """
232 channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,)
253 channel = self.make_request(
254 "POST",
255 self.url,
256 access_token=self.admin_user_tok,
257 )
233258
234259 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
235260 self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
242267 If parameters are invalid, an error is returned.
243268 """
244269 channel = self.make_request(
245 "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
270 "POST",
271 self.url + "?before_ts=-1234",
272 access_token=self.admin_user_tok,
246273 )
247274
248275 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
303330 self.assertEqual(200, channel.code, msg=channel.json_body)
304331 self.assertEqual(1, channel.json_body["total"])
305332 self.assertEqual(
306 media_id, channel.json_body["deleted_media"][0],
333 media_id,
334 channel.json_body["deleted_media"][0],
307335 )
308336
309337 self._access_media(server_and_media_id, False)
339367 self.assertEqual(200, channel.code, msg=channel.json_body)
340368 self.assertEqual(1, channel.json_body["total"])
341369 self.assertEqual(
342 server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
370 server_and_media_id.split("/")[1],
371 channel.json_body["deleted_media"][0],
343372 )
344373
345374 self._access_media(server_and_media_id, False)
373402 self.assertEqual(200, channel.code, msg=channel.json_body)
374403 self.assertEqual(1, channel.json_body["total"])
375404 self.assertEqual(
376 server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
405 server_and_media_id.split("/")[1],
406 channel.json_body["deleted_media"][0],
377407 )
378408
379409 self._access_media(server_and_media_id, False)
416446 self.assertEqual(200, channel.code, msg=channel.json_body)
417447 self.assertEqual(1, channel.json_body["total"])
418448 self.assertEqual(
419 server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
449 server_and_media_id.split("/")[1],
450 channel.json_body["deleted_media"][0],
420451 )
421452
422453 self._access_media(server_and_media_id, False)
460491 self.assertEqual(200, channel.code, msg=channel.json_body)
461492 self.assertEqual(1, channel.json_body["total"])
462493 self.assertEqual(
463 server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
494 server_and_media_id.split("/")[1],
495 channel.json_body["deleted_media"][0],
464496 )
465497
466498 self._access_media(server_and_media_id, False)
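The media tests above build endpoint paths by interpolating a server name and media ID; percent-encoding both segments, as some hunks do with `urllib.parse.quote`, can be sketched as follows (the helper name is illustrative; the example values come from the tests):

```python
from urllib.parse import quote

def media_delete_url(server_name, media_id):
    # Percent-encode both path segments before interpolating them
    return "/_synapse/admin/v1/media/%s/%s" % (quote(server_name), quote(media_id))

assert media_delete_url("example.org", "abcde12345") == (
    "/_synapse/admin/v1/media/example.org/abcde12345"
)
```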
126126 self._assert_peek(room_id, expect_code=403)
127127
128128 def _assert_peek(self, room_id, expect_code):
129 """Assert that the admin user can (or cannot) peek into the room.
130 """
129 """Assert that the admin user can (or cannot) peek into the room."""
131130
132131 url = "rooms/%s/initialSync" % (room_id,)
133132 channel = self.make_request(
185184 """
186185
187186 channel = self.make_request(
188 "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
187 "POST",
188 self.url,
189 json.dumps({}),
190 access_token=self.other_user_tok,
189191 )
190192
191193 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
198200 url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
199201
200202 channel = self.make_request(
201 "POST", url, json.dumps({}), access_token=self.admin_user_tok,
203 "POST",
204 url,
205 json.dumps({}),
206 access_token=self.admin_user_tok,
202207 )
203208
204209 self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
211216 url = "/_synapse/admin/v1/rooms/invalidroom/delete"
212217
213218 channel = self.make_request(
214 "POST", url, json.dumps({}), access_token=self.admin_user_tok,
219 "POST",
220 url,
221 json.dumps({}),
222 access_token=self.admin_user_tok,
215223 )
216224
217225 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
218226 self.assertEqual(
219 "invalidroom is not a legal room ID", channel.json_body["error"],
227 "invalidroom is not a legal room ID",
228 channel.json_body["error"],
220229 )
221230
222231 def test_new_room_user_does_not_exist(self):
253262
254263 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
255264 self.assertEqual(
256 "User must be our own: @not:exist.bla", channel.json_body["error"],
265 "User must be our own: @not:exist.bla",
266 channel.json_body["error"],
257267 )
258268
259269 def test_block_is_not_bool(self):
490500 self._assert_peek(self.room_id, expect_code=403)
491501
492502 def _is_blocked(self, room_id, expect=True):
493 """Assert that the room is blocked or not
494 """
503 """Assert that the room is blocked or not"""
495504 d = self.store.is_room_blocked(room_id)
496505 if expect:
497506 self.assertTrue(self.get_success(d))
499508 self.assertIsNone(self.get_success(d))
500509
501510 def _has_no_members(self, room_id):
502 """Assert there is now no longer anyone in the room
503 """
511 """Assert there is now no longer anyone in the room"""
504512 users_in_room = self.get_success(self.store.get_users_in_room(room_id))
505513 self.assertEqual([], users_in_room)
506514
507515 def _is_member(self, room_id, user_id):
508 """Test that the user is a member of the room
509 """
516 """Test that the user is a member of the room"""
510517 users_in_room = self.get_success(self.store.get_users_in_room(room_id))
511518 self.assertIn(user_id, users_in_room)
512519
513520 def _is_purged(self, room_id):
514 """Test that the following tables have been purged of all rows related to the room.
515 """
521 """Test that the following tables have been purged of all rows related to the room."""
516522 for table in PURGE_TABLES:
517523 count = self.get_success(
518524 self.store.db_pool.simple_select_one_onecol(
526532 self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
527533
528534 def _assert_peek(self, room_id, expect_code):
529 """Assert that the admin user can (or cannot) peek into the room.
530 """
535 """Assert that the admin user can (or cannot) peek into the room."""
531536
532537 url = "rooms/%s/initialSync" % (room_id,)
533538 channel = self.make_request(
547552
548553
549554 class PurgeRoomTestCase(unittest.HomeserverTestCase):
550 """Test /purge_room admin API.
551 """
555 """Test /purge_room admin API."""
552556
553557 servlets = [
554558 synapse.rest.admin.register_servlets,
593597
594598
595599 class RoomTestCase(unittest.HomeserverTestCase):
596 """Test /room admin API.
597 """
600 """Test /room admin API."""
598601
599602 servlets = [
600603 synapse.rest.admin.register_servlets,
622625 # Request the list of rooms
623626 url = "/_synapse/admin/v1/rooms"
624627 channel = self.make_request(
625 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
628 "GET",
629 url.encode("ascii"),
630 access_token=self.admin_user_tok,
626631 )
627632
628633 # Check request completed successfully
684689 # Set the name of the rooms so we get a consistent returned ordering
685690 for idx, room_id in enumerate(room_ids):
686691 self.helper.send_state(
687 room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
692 room_id,
693 "m.room.name",
694 {"name": str(idx)},
695 tok=self.admin_user_tok,
688696 )
689697
690698 # Request the list of rooms
703711 "name",
704712 )
705713 channel = self.make_request(
706 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
714 "GET",
715 url.encode("ascii"),
716 access_token=self.admin_user_tok,
707717 )
708718 self.assertEqual(
709719 200, int(channel.result["code"]), msg=channel.result["body"]
743753
744754 url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
745755 channel = self.make_request(
746 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
756 "GET",
757 url.encode("ascii"),
758 access_token=self.admin_user_tok,
747759 )
748760 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
749761
787799
788800 # Set a name for the room
789801 self.helper.send_state(
790 room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
802 room_id,
803 "m.room.name",
804 {"name": test_room_name},
805 tok=self.admin_user_tok,
791806 )
792807
793808 # Request the list of rooms
794809 url = "/_synapse/admin/v1/rooms"
795810 channel = self.make_request(
796 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
811 "GET",
812 url.encode("ascii"),
813 access_token=self.admin_user_tok,
797814 )
798815 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
799816
859876 )
860877
861878 def _order_test(
862 order_type: str, expected_room_list: List[str], reverse: bool = False,
879 order_type: str,
880 expected_room_list: List[str],
881 reverse: bool = False,
863882 ):
864883 """Request the list of rooms in a certain order. Assert that order is what
865884 we expect
874893 if reverse:
875894 url += "&dir=b"
876895 channel = self.make_request(
877 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
896 "GET",
897 url.encode("ascii"),
898 access_token=self.admin_user_tok,
878899 )
879900 self.assertEqual(200, channel.code, msg=channel.json_body)
880901
906927
907928 # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
908929 self.helper.send_state(
909 room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
930 room_id_1,
931 "m.room.name",
932 {"name": "A"},
933 tok=self.admin_user_tok,
910934 )
911935 self.helper.send_state(
912 room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
936 room_id_2,
937 "m.room.name",
938 {"name": "B"},
939 tok=self.admin_user_tok,
913940 )
914941 self.helper.send_state(
915 room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
942 room_id_3,
943 "m.room.name",
944 {"name": "C"},
945 tok=self.admin_user_tok,
916946 )
917947
918948 # Set room canonical room aliases
9891019
9901020 # Set the name for each room
9911021 self.helper.send_state(
992 room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
1022 room_id_1,
1023 "m.room.name",
1024 {"name": room_name_1},
1025 tok=self.admin_user_tok,
9931026 )
9941027 self.helper.send_state(
995 room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
1028 room_id_2,
1029 "m.room.name",
1030 {"name": room_name_2},
1031 tok=self.admin_user_tok,
9961032 )
9971033
9981034 def _search_test(
10101046 """
10111047 url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
10121048 channel = self.make_request(
1013 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1049 "GET",
1050 url.encode("ascii"),
1051 access_token=self.admin_user_tok,
10141052 )
10151053 self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
10161054
10701108
10711109 # Set the name for each room
10721110 self.helper.send_state(
1073 room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
1111 room_id_1,
1112 "m.room.name",
1113 {"name": room_name_1},
1114 tok=self.admin_user_tok,
10741115 )
10751116 self.helper.send_state(
1076 room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
1117 room_id_2,
1118 "m.room.name",
1119 {"name": room_name_2},
1120 tok=self.admin_user_tok,
10771121 )
10781122
10791123 url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
10801124 channel = self.make_request(
1081 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1125 "GET",
1126 url.encode("ascii"),
1127 access_token=self.admin_user_tok,
10821128 )
10831129 self.assertEqual(200, channel.code, msg=channel.json_body)
10841130
11081154
11091155 url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
11101156 channel = self.make_request(
1111 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1157 "GET",
1158 url.encode("ascii"),
1159 access_token=self.admin_user_tok,
11121160 )
11131161 self.assertEqual(200, channel.code, msg=channel.json_body)
11141162 self.assertEqual(1, channel.json_body["joined_local_devices"])
11201168
11211169 url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
11221170 channel = self.make_request(
1123 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1171 "GET",
1172 url.encode("ascii"),
1173 access_token=self.admin_user_tok,
11241174 )
11251175 self.assertEqual(200, channel.code, msg=channel.json_body)
11261176 self.assertEqual(2, channel.json_body["joined_local_devices"])
11301180 self.helper.leave(room_id_1, user_1, tok=user_tok_1)
11311181 url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
11321182 channel = self.make_request(
1133 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1183 "GET",
1184 url.encode("ascii"),
1185 access_token=self.admin_user_tok,
11341186 )
11351187 self.assertEqual(200, channel.code, msg=channel.json_body)
11361188 self.assertEqual(0, channel.json_body["joined_local_devices"])
11591211
11601212 url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
11611213 channel = self.make_request(
1162 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1214 "GET",
1215 url.encode("ascii"),
1216 access_token=self.admin_user_tok,
11631217 )
11641218 self.assertEqual(200, channel.code, msg=channel.json_body)
11651219
11701224
11711225 url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
11721226 channel = self.make_request(
1173 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1227 "GET",
1228 url.encode("ascii"),
1229 access_token=self.admin_user_tok,
11741230 )
11751231 self.assertEqual(200, channel.code, msg=channel.json_body)
11761232
11861242
11871243 url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
11881244 channel = self.make_request(
1189 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
1245 "GET",
1246 url.encode("ascii"),
1247 access_token=self.admin_user_tok,
11901248 )
11911249 self.assertEqual(200, channel.code, msg=channel.json_body)
11921250 self.assertIn("state", channel.json_body)
13411399 # Check that the user is a member of the room
13421400
13431401 channel = self.make_request(
1344 "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
1402 "GET",
1403 "/_matrix/client/r0/joined_rooms",
1404 access_token=self.second_tok,
13451405 )
13461406 self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
13471407 self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
13881448 # Check that the server admin is a member of the room
13891449
13901450 channel = self.make_request(
1391 "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
1451 "GET",
1452 "/_matrix/client/r0/joined_rooms",
1453 access_token=self.admin_user_tok,
13921454 )
13931455 self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
13941456 self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
14101472 # Check that the user is a member of the room
14111473
14121474 channel = self.make_request(
1413 "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
1475 "GET",
1476 "/_matrix/client/r0/joined_rooms",
1477 access_token=self.second_tok,
14141478 )
14151479 self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
14161480 self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
14391503 # Check that the user is a member of the room
14401504
14411505 channel = self.make_request(
1442 "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
1506 "GET",
1507 "/_matrix/client/r0/joined_rooms",
1508 access_token=self.second_tok,
14431509 )
14441510 self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
14451511 self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
1512
1513 def test_context_as_non_admin(self):
1514 """
1515 Test that, without being admin, one cannot use the context admin API
1516 """
1517 # Create a room.
1518 user_id = self.register_user("test", "test")
1519 user_tok = self.login("test", "test")
1520
1521 self.register_user("test_2", "test")
1522 user_tok_2 = self.login("test_2", "test")
1523
1524 room_id = self.helper.create_room_as(user_id, tok=user_tok)
1525
1526 # Populate the room with events.
1527 events = []
1528 for i in range(30):
1529 events.append(
1530 self.helper.send_event(
1531 room_id, "com.example.test", content={"index": i}, tok=user_tok
1532 )
1533 )
1534
1535 # Now attempt to find the context using the admin API without being admin.
1536 midway = (len(events) - 1) // 2
1537 for tok in [user_tok, user_tok_2]:
1538 channel = self.make_request(
1539 "GET",
1540 "/_synapse/admin/v1/rooms/%s/context/%s"
1541 % (room_id, events[midway]["event_id"]),
1542 access_token=tok,
1543 )
1544 self.assertEquals(
1545 403, int(channel.result["code"]), msg=channel.result["body"]
1546 )
1547 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
1548
1549 def test_context_as_admin(self):
1550 """
1551 Test that, as admin, we can find the context of an event without having joined the room.
1552 """
1553
1554 # Create a room. We're not part of it.
1555 user_id = self.register_user("test", "test")
1556 user_tok = self.login("test", "test")
1557 room_id = self.helper.create_room_as(user_id, tok=user_tok)
1558
1559 # Populate the room with events.
1560 events = []
1561 for i in range(30):
1562 events.append(
1563 self.helper.send_event(
1564 room_id, "com.example.test", content={"index": i}, tok=user_tok
1565 )
1566 )
1567
1568 # Now let's fetch the context for this room.
1569 midway = (len(events) - 1) // 2
1570 channel = self.make_request(
1571 "GET",
1572 "/_synapse/admin/v1/rooms/%s/context/%s"
1573 % (room_id, events[midway]["event_id"]),
1574 access_token=self.admin_user_tok,
1575 )
1576 self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
1577 self.assertEquals(
1578 channel.json_body["event"]["event_id"], events[midway]["event_id"]
1579 )
1580
1581 for i, found_event in enumerate(channel.json_body["events_before"]):
1582 for j, posted_event in enumerate(events):
1583 if found_event["event_id"] == posted_event["event_id"]:
1584 self.assertTrue(j < midway)
1585 break
1586 else:
1587 self.fail("Event %s from events_before not found" % found_event["event_id"])
1588
1589 for i, found_event in enumerate(channel.json_body["events_after"]):
1590 for j, posted_event in enumerate(events):
1591 if found_event["event_id"] == posted_event["event_id"]:
1592 self.assertTrue(j > midway)
1593 break
1594 else:
1595 self.fail("Event %s from events_after not found" % found_event["event_id"])
14461596
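The two context tests added above assert that every event the API returns in `events_before`/`events_after` sits on the correct side of the pivot event. The same invariant can be sketched as a standalone helper (names here are illustrative, not part of Synapse):

```python
def check_context_partition(posted_ids, before_ids, after_ids, pivot_id):
    """Return True iff every id in before_ids precedes the pivot and every
    id in after_ids follows it, judged by position in posted_ids."""
    pivot = posted_ids.index(pivot_id)
    # Map each posted event id to its send order.
    pos = {event_id: i for i, event_id in enumerate(posted_ids)}
    return all(pos[e] < pivot for e in before_ids) and all(
        pos[e] > pivot for e in after_ids
    )


posted = ["$a", "$b", "$c", "$d", "$e"]
print(check_context_partition(posted, ["$a", "$b"], ["$d", "$e"], "$c"))  # True
```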
14471597
14481598 class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
14701620 )
14711621
14721622 def test_public_room(self):
1473 """Test that getting admin in a public room works.
1474 """
1623 """Test that getting admin in a public room works."""
14751624 room_id = self.helper.create_room_as(
14761625 self.creator, tok=self.creator_tok, is_public=True
14771626 )
14961645 )
14971646
14981647 def test_private_room(self):
1499 """Test that getting admin in a private room works and we get invited.
1500 """
1648 """Test that getting admin in a private room works and we get invited."""
15011649 room_id = self.helper.create_room_as(
1502 self.creator, tok=self.creator_tok, is_public=False,
1650 self.creator,
1651 tok=self.creator_tok,
1652 is_public=False,
15031653 )
15041654
15051655 channel = self.make_request(
15231673 )
15241674
15251675 def test_other_user(self):
1526 """Test that giving admin in a public room works to a non-admin user works.
1527 """
1676 """Test that granting admin in a public room to a non-admin user works."""
15281677 room_id = self.helper.create_room_as(
15291678 self.creator, tok=self.creator_tok, is_public=True
15301679 )
15491698 )
15501699
15511700 def test_not_enough_power(self):
1552 """Test that we get a sensible error if there are no local room admins.
1553 """
1701 """Test that we get a sensible error if there are no local room admins."""
15541702 room_id = self.helper.create_room_as(
15551703 self.creator, tok=self.creator_tok, is_public=True
15561704 )
5454 If the user is not a server admin, an error 403 is returned.
5555 """
5656 channel = self.make_request(
57 "GET", self.url, json.dumps({}), access_token=self.other_user_tok,
57 "GET",
58 self.url,
59 json.dumps({}),
60 access_token=self.other_user_tok,
5861 )
5962
6063 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
6669 """
6770 # unknown order_by
6871 channel = self.make_request(
69 "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
72 "GET",
73 self.url + "?order_by=bar",
74 access_token=self.admin_user_tok,
7075 )
7176
7277 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
7479
7580 # negative from
7681 channel = self.make_request(
77 "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
82 "GET",
83 self.url + "?from=-5",
84 access_token=self.admin_user_tok,
7885 )
7986
8087 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
8289
8390 # negative limit
8491 channel = self.make_request(
85 "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
92 "GET",
93 self.url + "?limit=-5",
94 access_token=self.admin_user_tok,
8695 )
8796
8897 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
9099
91100 # negative from_ts
92101 channel = self.make_request(
93 "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
102 "GET",
103 self.url + "?from_ts=-1234",
104 access_token=self.admin_user_tok,
94105 )
95106
96107 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
98109
99110 # negative until_ts
100111 channel = self.make_request(
101 "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
112 "GET",
113 self.url + "?until_ts=-1234",
114 access_token=self.admin_user_tok,
102115 )
103116
104117 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
116129
117130 # empty search term
118131 channel = self.make_request(
119 "GET", self.url + "?search_term=", access_token=self.admin_user_tok,
132 "GET",
133 self.url + "?search_term=",
134 access_token=self.admin_user_tok,
120135 )
121136
122137 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
124139
125140 # invalid search order
126141 channel = self.make_request(
127 "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
142 "GET",
143 self.url + "?dir=bar",
144 access_token=self.admin_user_tok,
128145 )
129146
130147 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
137154 self._create_users_with_media(10, 2)
138155
139156 channel = self.make_request(
140 "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
157 "GET",
158 self.url + "?limit=5",
159 access_token=self.admin_user_tok,
141160 )
142161
143162 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
153172 self._create_users_with_media(20, 2)
154173
155174 channel = self.make_request(
156 "GET", self.url + "?from=5", access_token=self.admin_user_tok,
175 "GET",
176 self.url + "?from=5",
177 access_token=self.admin_user_tok,
157178 )
158179
159180 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
169190 self._create_users_with_media(20, 2)
170191
171192 channel = self.make_request(
172 "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
193 "GET",
194 self.url + "?from=5&limit=10",
195 access_token=self.admin_user_tok,
173196 )
174197
175198 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
189212 # `next_token` does not appear
190213 # Number of results is the number of entries
191214 channel = self.make_request(
192 "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
215 "GET",
216 self.url + "?limit=20",
217 access_token=self.admin_user_tok,
193218 )
194219
195220 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
200225 # `next_token` does not appear
201226 # Number of max results is larger than the number of entries
202227 channel = self.make_request(
203 "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
228 "GET",
229 self.url + "?limit=21",
230 access_token=self.admin_user_tok,
204231 )
205232
206233 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
211238 # `next_token` does appear
212239 # Number of max results is smaller than the number of entries
213240 channel = self.make_request(
214 "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
241 "GET",
242 self.url + "?limit=19",
243 access_token=self.admin_user_tok,
215244 )
216245
217246 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
222251 # Set `from` to the value of `next_token` to request the remaining entries
223252 # Check `next_token` does not appear
224253 channel = self.make_request(
225 "GET", self.url + "?from=19", access_token=self.admin_user_tok,
254 "GET",
255 self.url + "?from=19",
256 access_token=self.admin_user_tok,
226257 )
227258
228259 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
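The `limit`/`from`/`next_token` tests above encode the admin API's pagination contract: keep requesting with `from` set to the previous response's `next_token` until the response omits it. A client-side sketch of that loop against a stand-in fetch function (the fetch is a fake mirroring the response shape, not a real HTTP call):

```python
def fetch_page(items, from_, limit):
    # Stand-in for one GET with ?from=...&limit=...; mimics the API shape:
    # `next_token` is present only while more entries remain.
    resp = {"users": items[from_ : from_ + limit], "total": len(items)}
    if from_ + limit < len(items):
        resp["next_token"] = from_ + limit
    return resp


def fetch_all(items, limit=5):
    collected, from_ = [], 0
    while True:
        resp = fetch_page(items, from_, limit)
        collected.extend(resp["users"])
        if "next_token" not in resp:
            return collected
        from_ = resp["next_token"]


print(len(fetch_all(list(range(19)), limit=5)))  # 19
```

Note how the boundary cases in the tests fall out of this shape: with exactly `limit` entries left, `from_ + limit` equals the total, so `next_token` is omitted.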
236267 if users have no media created
237268 """
238269
239 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
270 channel = self.make_request(
271 "GET",
272 self.url,
273 access_token=self.admin_user_tok,
274 )
240275
241276 self.assertEqual(200, channel.code, msg=channel.json_body)
242277 self.assertEqual(0, channel.json_body["total"])
263298 # order by user_id
264299 self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"])
265300 self._order_test(
266 "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f",
267 )
268 self._order_test(
269 "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b",
301 "user_id",
302 ["@user_a:test", "@user_b:test", "@user_c:test"],
303 "f",
304 )
305 self._order_test(
306 "user_id",
307 ["@user_c:test", "@user_b:test", "@user_a:test"],
308 "b",
270309 )
271310
272311 # order by displayname
274313 "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"]
275314 )
276315 self._order_test(
277 "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f",
278 )
279 self._order_test(
280 "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b",
316 "displayname",
317 ["@user_c:test", "@user_b:test", "@user_a:test"],
318 "f",
319 )
320 self._order_test(
321 "displayname",
322 ["@user_a:test", "@user_b:test", "@user_c:test"],
323 "b",
281324 )
282325
283326 # order by media_length
284327 self._order_test(
285 "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"],
286 )
287 self._order_test(
288 "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
289 )
290 self._order_test(
291 "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
328 "media_length",
329 ["@user_a:test", "@user_c:test", "@user_b:test"],
330 )
331 self._order_test(
332 "media_length",
333 ["@user_a:test", "@user_c:test", "@user_b:test"],
334 "f",
335 )
336 self._order_test(
337 "media_length",
338 ["@user_b:test", "@user_c:test", "@user_a:test"],
339 "b",
292340 )
293341
294342 # order by media_count
295343 self._order_test(
296 "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"],
297 )
298 self._order_test(
299 "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
300 )
301 self._order_test(
302 "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
344 "media_count",
345 ["@user_a:test", "@user_c:test", "@user_b:test"],
346 )
347 self._order_test(
348 "media_count",
349 ["@user_a:test", "@user_c:test", "@user_b:test"],
350 "f",
351 )
352 self._order_test(
353 "media_count",
354 ["@user_b:test", "@user_c:test", "@user_a:test"],
355 "b",
303356 )
304357
305358 def test_from_until_ts(self):
312365 ts1 = self.clock.time_msec()
313366
314367 # list all media when filter is not set
315 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
368 channel = self.make_request(
369 "GET",
370 self.url,
371 access_token=self.admin_user_tok,
372 )
316373 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
317374 self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
318375
319376 # filter media starting at `ts1` after creating first media
320377 # result is 0
321378 channel = self.make_request(
322 "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
379 "GET",
380 self.url + "?from_ts=%s" % (ts1,),
381 access_token=self.admin_user_tok,
323382 )
324383 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
325384 self.assertEqual(channel.json_body["total"], 0)
341400
342401 # filter media until `ts2` and earlier
343402 channel = self.make_request(
344 "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
403 "GET",
404 self.url + "?until_ts=%s" % (ts2,),
405 access_token=self.admin_user_tok,
345406 )
346407 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
347408 self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
350411 self._create_users_with_media(20, 1)
351412
352413 # check that, without a filter, all users are returned
353 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
414 channel = self.make_request(
415 "GET",
416 self.url,
417 access_token=self.admin_user_tok,
418 )
354419 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
355420 self.assertEqual(channel.json_body["total"], 20)
356421
375440
376441 # filter and get empty result
377442 channel = self.make_request(
378 "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
443 "GET",
444 self.url + "?search_term=foobar",
445 access_token=self.admin_user_tok,
379446 )
380447 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
381448 self.assertEqual(channel.json_body["total"], 0)
440507 if dir is not None and dir in ("b", "f"):
441508 url += "&dir=%s" % (dir,)
442509 channel = self.make_request(
443 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
510 "GET",
511 url.encode("ascii"),
512 access_token=self.admin_user_tok,
444513 )
445514 self.assertEqual(200, channel.code, msg=channel.json_body)
446515 self.assertEqual(channel.json_body["total"], len(expected_user_list))
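The `_order_test` helper above checks server-side ordering via `order_by` plus `dir=f|b`; the expected lists in the tests can be reproduced client-side with an ordinary sort, reversed when `dir` is `b`. A sketch (field names are taken from the tests, the helper itself is hypothetical):

```python
def expected_order(rows, order_by, direction="f"):
    # "f" sorts forwards (ascending), "b" backwards (descending).
    return [
        row["user_id"]
        for row in sorted(rows, key=lambda r: r[order_by], reverse=direction == "b")
    ]


rows = [
    {"user_id": "@user_a:test", "media_length": 10},
    {"user_id": "@user_b:test", "media_length": 30},
    {"user_id": "@user_c:test", "media_length": 20},
]
print(expected_order(rows, "media_length", "b"))
# -> ['@user_b:test', '@user_c:test', '@user_a:test']
```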
527527 search_field: Field which is to request: `name` or `user_id`
528528 expected_http_code: The expected http code for the request
529529 """
530 url = self.url + "?%s=%s" % (search_field, search_term,)
530 url = self.url + "?%s=%s" % (
531 search_field,
532 search_term,
533 )
531534 channel = self.make_request(
532 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
535 "GET",
536 url.encode("ascii"),
537 access_token=self.admin_user_tok,
533538 )
534539 self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
535540
589594
590595 # negative limit
591596 channel = self.make_request(
592 "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
597 "GET",
598 self.url + "?limit=-5",
599 access_token=self.admin_user_tok,
593600 )
594601
595602 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
597604
598605 # negative from
599606 channel = self.make_request(
600 "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
607 "GET",
608 self.url + "?from=-5",
609 access_token=self.admin_user_tok,
601610 )
602611
603612 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
605614
606615 # invalid guests
607616 channel = self.make_request(
608 "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
617 "GET",
618 self.url + "?guests=not_bool",
619 access_token=self.admin_user_tok,
609620 )
610621
611622 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
613624
614625 # invalid deactivated
615626 channel = self.make_request(
616 "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
627 "GET",
628 self.url + "?deactivated=not_bool",
629 access_token=self.admin_user_tok,
617630 )
618631
619632 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
629642 self._create_users(number_users - 1)
630643
631644 channel = self.make_request(
632 "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
645 "GET",
646 self.url + "?limit=5",
647 access_token=self.admin_user_tok,
633648 )
634649
635650 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
648663 self._create_users(number_users - 1)
649664
650665 channel = self.make_request(
651 "GET", self.url + "?from=5", access_token=self.admin_user_tok,
666 "GET",
667 self.url + "?from=5",
668 access_token=self.admin_user_tok,
652669 )
653670
654671 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
667684 self._create_users(number_users - 1)
668685
669686 channel = self.make_request(
670 "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
687 "GET",
688 self.url + "?from=5&limit=10",
689 access_token=self.admin_user_tok,
671690 )
672691
673692 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
688707 # `next_token` does not appear
689708 # Number of results is the number of entries
690709 channel = self.make_request(
691 "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
710 "GET",
711 self.url + "?limit=20",
712 access_token=self.admin_user_tok,
692713 )
693714
694715 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
699720 # `next_token` does not appear
700721 # Number of max results is larger than the number of entries
701722 channel = self.make_request(
702 "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
723 "GET",
724 self.url + "?limit=21",
725 access_token=self.admin_user_tok,
703726 )
704727
705728 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
710733 # `next_token` does appear
711734 # Number of max results is smaller than the number of entries
712735 channel = self.make_request(
713 "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
736 "GET",
737 self.url + "?limit=19",
738 access_token=self.admin_user_tok,
714739 )
715740
716741 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
722747 # Set `from` to the value of `next_token` to request the remaining entries
723748 # `next_token` does not appear
724749 channel = self.make_request(
725 "GET", self.url + "?from=19", access_token=self.admin_user_tok,
750 "GET",
751 self.url + "?from=19",
752 access_token=self.admin_user_tok,
726753 )
727754
728755 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
741768 self.assertIn("admin", u)
742769 self.assertIn("user_type", u)
743770 self.assertIn("deactivated", u)
771 self.assertIn("shadow_banned", u)
744772 self.assertIn("displayname", u)
745773 self.assertIn("avatar_url", u)
746774
752780 """
753781 for i in range(1, number_users + 1):
754782 self.register_user(
755 "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
783 "user%d" % i,
784 "pass%d" % i,
785 admin=False,
786 displayname="Name %d" % i,
756787 )
757788
758789
807838 self.assertEqual("You are not a server admin", channel.json_body["error"])
808839
809840 channel = self.make_request(
810 "POST", url, access_token=self.other_user_token, content=b"{}",
841 "POST",
842 url,
843 access_token=self.other_user_token,
844 content=b"{}",
811845 )
812846
813847 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
861895
862896 # Get user
863897 channel = self.make_request(
864 "GET", self.url_other_user, access_token=self.admin_user_tok,
898 "GET",
899 self.url_other_user,
900 access_token=self.admin_user_tok,
865901 )
866902
867903 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
885921
886922 # Get user
887923 channel = self.make_request(
888 "GET", self.url_other_user, access_token=self.admin_user_tok,
924 "GET",
925 self.url_other_user,
926 access_token=self.admin_user_tok,
889927 )
890928
891929 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
904942
905943 # Get user
906944 channel = self.make_request(
907 "GET", self.url_other_user, access_token=self.admin_user_tok,
945 "GET",
946 self.url_other_user,
947 access_token=self.admin_user_tok,
908948 )
909949
910950 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
928968
929969 # Get user
930970 channel = self.make_request(
931 "GET", self.url_other_user, access_token=self.admin_user_tok,
971 "GET",
972 self.url_other_user,
973 access_token=self.admin_user_tok,
932974 )
933975
934976 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
941983 self._is_erased("@user:test", False)
942984
943985 def _is_erased(self, user_id: str, expect: bool) -> None:
944 """Assert that the user is erased or not
945 """
986 """Assert that the user is erased or not"""
946987 d = self.store.is_user_erased(user_id)
947988 if expect:
948989 self.assertTrue(self.get_success(d))
9761017 """
9771018 url = "/_synapse/admin/v2/users/@bob:test"
9781019
979 channel = self.make_request("GET", url, access_token=self.other_user_token,)
1020 channel = self.make_request(
1021 "GET",
1022 url,
1023 access_token=self.other_user_token,
1024 )
9801025
9811026 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
9821027 self.assertEqual("You are not a server admin", channel.json_body["error"])
9831028
9841029 channel = self.make_request(
985 "PUT", url, access_token=self.other_user_token, content=b"{}",
1030 "PUT",
1031 url,
1032 access_token=self.other_user_token,
1033 content=b"{}",
9861034 )
9871035
9881036 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
10351083 self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
10361084
10371085 # Get user
1038 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1086 channel = self.make_request(
1087 "GET",
1088 url,
1089 access_token=self.admin_user_tok,
1090 )
10391091
10401092 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
10411093 self.assertEqual("@bob:test", channel.json_body["name"])
10801132 self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
10811133
10821134 # Get user
1083 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1135 channel = self.make_request(
1136 "GET",
1137 url,
1138 access_token=self.admin_user_tok,
1139 )
10841140
10851141 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
10861142 self.assertEqual("@bob:test", channel.json_body["name"])
10901146 self.assertEqual(False, channel.json_body["admin"])
10911147 self.assertEqual(False, channel.json_body["is_guest"])
10921148 self.assertEqual(False, channel.json_body["deactivated"])
1149 self.assertEqual(False, channel.json_body["shadow_banned"])
10931150 self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
10941151
10951152 @override_config(
13051362
13061363 # Get user
13071364 channel = self.make_request(
1308 "GET", self.url_other_user, access_token=self.admin_user_tok,
1365 "GET",
1366 self.url_other_user,
1367 access_token=self.admin_user_tok,
13091368 )
13101369
13111370 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
13361395
13371396 # Get user
13381397 channel = self.make_request(
1339 "GET", self.url_other_user, access_token=self.admin_user_tok,
1398 "GET",
1399 self.url_other_user,
1400 access_token=self.admin_user_tok,
13401401 )
13411402
13421403 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
13591420
13601421 # Get user
13611422 channel = self.make_request(
1362 "GET", self.url_other_user, access_token=self.admin_user_tok,
1423 "GET",
1424 self.url_other_user,
1425 access_token=self.admin_user_tok,
13631426 )
13641427
13651428 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
13891452
13901453 # Get user
13911454 channel = self.make_request(
1392 "GET", self.url_other_user, access_token=self.admin_user_tok,
1455 "GET",
1456 self.url_other_user,
1457 access_token=self.admin_user_tok,
13931458 )
13941459
13951460 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
14871552
14881553 # Get user
14891554 channel = self.make_request(
1490 "GET", self.url_other_user, access_token=self.admin_user_tok,
1555 "GET",
1556 self.url_other_user,
1557 access_token=self.admin_user_tok,
14911558 )
14921559
14931560 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
15161583
15171584 # Get user
15181585 channel = self.make_request(
1519 "GET", self.url_other_user, access_token=self.admin_user_tok,
1586 "GET",
1587 self.url_other_user,
1588 access_token=self.admin_user_tok,
15201589 )
15211590
15221591 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
15451614 self.assertEqual("bob", channel.json_body["displayname"])
15461615
15471616 # Get user
1548 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1617 channel = self.make_request(
1618 "GET",
1619 url,
1620 access_token=self.admin_user_tok,
1621 )
15491622
15501623 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
15511624 self.assertEqual("@bob:test", channel.json_body["name"])
15651638 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
15661639
15671640 # Check user is not deactivated
1568 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1641 channel = self.make_request(
1642 "GET",
1643 url,
1644 access_token=self.admin_user_tok,
1645 )
15691646
15701647 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
15711648 self.assertEqual("@bob:test", channel.json_body["name"])
15751652 self.assertEqual(0, channel.json_body["deactivated"])
15761653
15771654 def _is_erased(self, user_id, expect):
1578 """Assert that the user is erased or not
1579 """
1655 """Assert that the user is erased or not"""
15801656 d = self.store.is_user_erased(user_id)
15811657 if expect:
15821658 self.assertTrue(self.get_success(d))
16161692 """
16171693 other_user_token = self.login("user", "pass")
16181694
1619 channel = self.make_request("GET", self.url, access_token=other_user_token,)
1695 channel = self.make_request(
1696 "GET",
1697 self.url,
1698 access_token=other_user_token,
1699 )
16201700
16211701 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
16221702 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
16261706 Tests that a lookup for a user that does not exist returns an empty list
16271707 """
16281708 url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
1629 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1709 channel = self.make_request(
1710 "GET",
1711 url,
1712 access_token=self.admin_user_tok,
1713 )
16301714
16311715 self.assertEqual(200, channel.code, msg=channel.json_body)
16321716 self.assertEqual(0, channel.json_body["total"])
16381722 """
16391723 url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
16401724
1641 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1725 channel = self.make_request(
1726 "GET",
1727 url,
1728 access_token=self.admin_user_tok,
1729 )
16421730
16431731 self.assertEqual(200, channel.code, msg=channel.json_body)
16441732 self.assertEqual(0, channel.json_body["total"])
16501738 if user has no memberships
16511739 """
16521740 # Get rooms
1653 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
1741 channel = self.make_request(
1742 "GET",
1743 self.url,
1744 access_token=self.admin_user_tok,
1745 )
16541746
16551747 self.assertEqual(200, channel.code, msg=channel.json_body)
16561748 self.assertEqual(0, channel.json_body["total"])
16671759 self.helper.create_room_as(self.other_user, tok=other_user_tok)
16681760
16691761 # Get rooms
1670 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
1762 channel = self.make_request(
1763 "GET",
1764 self.url,
1765 access_token=self.admin_user_tok,
1766 )
16711767
16721768 self.assertEqual(200, channel.code, msg=channel.json_body)
16731769 self.assertEqual(number_rooms, channel.json_body["total"])
17101806
17111807 # Now get rooms
17121808 url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
1713 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1809 channel = self.make_request(
1810 "GET",
1811 url,
1812 access_token=self.admin_user_tok,
1813 )
17141814
17151815 self.assertEqual(200, channel.code, msg=channel.json_body)
17161816 self.assertEqual(1, channel.json_body["total"])
17501850 """
17511851 other_user_token = self.login("user", "pass")
17521852
1753 channel = self.make_request("GET", self.url, access_token=other_user_token,)
1853 channel = self.make_request(
1854 "GET",
1855 self.url,
1856 access_token=other_user_token,
1857 )
17541858
17551859 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
17561860 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
17601864 Tests that a lookup for a user that does not exist returns a 404
17611865 """
17621866 url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
1763 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1867 channel = self.make_request(
1868 "GET",
1869 url,
1870 access_token=self.admin_user_tok,
1871 )
17641872
17651873 self.assertEqual(404, channel.code, msg=channel.json_body)
17661874 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
17711879 """
17721880 url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
17731881
1774 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1882 channel = self.make_request(
1883 "GET",
1884 url,
1885 access_token=self.admin_user_tok,
1886 )
17751887
17761888 self.assertEqual(400, channel.code, msg=channel.json_body)
17771889 self.assertEqual("Can only lookup local users", channel.json_body["error"])
17821894 """
17831895
17841896 # Get pushers
1785 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
1897 channel = self.make_request(
1898 "GET",
1899 self.url,
1900 access_token=self.admin_user_tok,
1901 )
17861902
17871903 self.assertEqual(200, channel.code, msg=channel.json_body)
17881904 self.assertEqual(0, channel.json_body["total"])
18091925 )
18101926
18111927 # Get pushers
1812 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
1928 channel = self.make_request(
1929 "GET",
1930 self.url,
1931 access_token=self.admin_user_tok,
1932 )
18131933
18141934 self.assertEqual(200, channel.code, msg=channel.json_body)
18151935 self.assertEqual(1, channel.json_body["total"])
18581978 """
18591979 other_user_token = self.login("user", "pass")
18601980
1861 channel = self.make_request("GET", self.url, access_token=other_user_token,)
1981 channel = self.make_request(
1982 "GET",
1983 self.url,
1984 access_token=other_user_token,
1985 )
18621986
18631987 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
18641988 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
18681992 Tests that a lookup for a user that does not exist returns a 404
18691993 """
18701994 url = "/_synapse/admin/v1/users/@unknown_person:test/media"
1871 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
1995 channel = self.make_request(
1996 "GET",
1997 url,
1998 access_token=self.admin_user_tok,
1999 )
18722000
18732001 self.assertEqual(404, channel.code, msg=channel.json_body)
18742002 self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
18792007 """
18802008 url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
18812009
1882 channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
2010 channel = self.make_request(
2011 "GET",
2012 url,
2013 access_token=self.admin_user_tok,
2014 )
18832015
18842016 self.assertEqual(400, channel.code, msg=channel.json_body)
18852017 self.assertEqual("Can only lookup local users", channel.json_body["error"])
18942026 self._create_media(other_user_tok, number_media)
18952027
18962028 channel = self.make_request(
1897 "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
2029 "GET",
2030 self.url + "?limit=5",
2031 access_token=self.admin_user_tok,
18982032 )
18992033
19002034 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
19132047 self._create_media(other_user_tok, number_media)
19142048
19152049 channel = self.make_request(
1916 "GET", self.url + "?from=5", access_token=self.admin_user_tok,
2050 "GET",
2051 self.url + "?from=5",
2052 access_token=self.admin_user_tok,
19172053 )
19182054
19192055 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
19322068 self._create_media(other_user_tok, number_media)
19332069
19342070 channel = self.make_request(
1935 "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
2071 "GET",
2072 self.url + "?from=5&limit=10",
2073 access_token=self.admin_user_tok,
19362074 )
19372075
19382076 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
19472085 """
19482086
19492087 channel = self.make_request(
1950 "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
2088 "GET",
2089 self.url + "?limit=-5",
2090 access_token=self.admin_user_tok,
19512091 )
19522092
19532093 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
19592099 """
19602100
19612101 channel = self.make_request(
1962 "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
2102 "GET",
2103 self.url + "?from=-5",
2104 access_token=self.admin_user_tok,
19632105 )
19642106
19652107 self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
19772119 # `next_token` does not appear
19782120 # Number of results is the number of entries
19792121 channel = self.make_request(
1980 "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
2122 "GET",
2123 self.url + "?limit=20",
2124 access_token=self.admin_user_tok,
19812125 )
19822126
19832127 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
19882132 # `next_token` does not appear
19892133 # Number of max results is larger than the number of entries
19902134 channel = self.make_request(
1991 "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
2135 "GET",
2136 self.url + "?limit=21",
2137 access_token=self.admin_user_tok,
19922138 )
19932139
19942140 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
19992145 # `next_token` does appear
20002146 # Number of max results is smaller than the number of entries
20012147 channel = self.make_request(
2002 "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
2148 "GET",
2149 self.url + "?limit=19",
2150 access_token=self.admin_user_tok,
20032151 )
20042152
20052153 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
20112159 # Set `from` to value of `next_token` for request remaining entries
20122160 # `next_token` does not appear
20132161 channel = self.make_request(
2014 "GET", self.url + "?from=19", access_token=self.admin_user_tok,
2162 "GET",
2163 self.url + "?from=19",
2164 access_token=self.admin_user_tok,
20152165 )
20162166
20172167 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
20252175 if user has no media created
20262176 """
20272177
2028 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
2178 channel = self.make_request(
2179 "GET",
2180 self.url,
2181 access_token=self.admin_user_tok,
2182 )
20292183
20302184 self.assertEqual(200, channel.code, msg=channel.json_body)
20312185 self.assertEqual(0, channel.json_body["total"])
20402194 other_user_tok = self.login("user", "pass")
20412195 self._create_media(other_user_tok, number_media)
20422196
2043 channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
2197 channel = self.make_request(
2198 "GET",
2199 self.url,
2200 access_token=self.admin_user_tok,
2201 )
20442202
20452203 self.assertEqual(200, channel.code, msg=channel.json_body)
20462204 self.assertEqual(number_media, channel.json_body["total"])
20672225 )
20682226
20692227 def _check_fields(self, content):
2070 """Checks that all attributes are present in content
2071 """
2228 """Checks that all attributes are present in content"""
20722229 for m in content:
20732230 self.assertIn("media_id", m)
20742231 self.assertIn("media_type", m)
20812238
20822239
20832240 class UserTokenRestTestCase(unittest.HomeserverTestCase):
2084 """Test for /_synapse/admin/v1/users/<user>/login
2085 """
2241 """Test for /_synapse/admin/v1/users/<user>/login"""
20862242
20872243 servlets = [
20882244 synapse.rest.admin.register_servlets,
21132269 return channel.json_body["access_token"]
21142270
21152271 def test_no_auth(self):
2116 """Try to login as a user without authentication.
2117 """
2272 """Try to login as a user without authentication."""
21182273 channel = self.make_request("POST", self.url, b"{}")
21192274
21202275 self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
21212276 self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
21222277
21232278 def test_not_admin(self):
2124 """Try to login as a user as a non-admin user.
2125 """
2279 """Try to login as a user as a non-admin user."""
21262280 channel = self.make_request(
21272281 "POST", self.url, b"{}", access_token=self.other_user_tok
21282282 )
21302284 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
21312285
21322286 def test_send_event(self):
2133 """Test that sending event as a user works.
2134 """
2287 """Test that sending event as a user works."""
21352288 # Create a room.
21362289 room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
21372290
21452298 self.assertEqual(event.sender, self.other_user)
21462299
21472300 def test_devices(self):
2148 """Tests that logging in as a user doesn't create a new device for them.
2149 """
2301 """Tests that logging in as a user doesn't create a new device for them."""
21502302 # Login in as the user
21512303 self._get_token()
21522304
21602312 self.assertEqual(len(channel.json_body["devices"]), 1)
21612313
21622314 def test_logout(self):
2163 """Test that calling `/logout` with the token works.
2164 """
2315 """Test that calling `/logout` with the token works."""
21652316 # Login in as the user
21662317 puppet_token = self._get_token()
21672318
22512402 }
22522403 )
22532404 def test_consent(self):
2254 """Test that sending a message is not subject to the privacy policies.
2255 """
2405 """Test that sending a message is not subject to the privacy policies."""
22562406 # Have the admin user accept the terms.
22572407 self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
22582408
23272477 self.register_user("user2", "pass")
23282478 other_user2_token = self.login("user2", "pass")
23292479
2330 channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
2480 channel = self.make_request(
2481 "GET",
2482 self.url1,
2483 access_token=other_user2_token,
2484 )
23312485 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
23322486 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
23332487
2334 channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
2488 channel = self.make_request(
2489 "GET",
2490 self.url2,
2491 access_token=other_user2_token,
2492 )
23352493 self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
23362494 self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
23372495
23422500 url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
23432501 url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
23442502
2345 channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
2503 channel = self.make_request(
2504 "GET",
2505 url1,
2506 access_token=self.admin_user_tok,
2507 )
23462508 self.assertEqual(400, channel.code, msg=channel.json_body)
23472509 self.assertEqual("Can only whois a local user", channel.json_body["error"])
23482510
2349 channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
2511 channel = self.make_request(
2512 "GET",
2513 url2,
2514 access_token=self.admin_user_tok,
2515 )
23502516 self.assertEqual(400, channel.code, msg=channel.json_body)
23512517 self.assertEqual("Can only whois a local user", channel.json_body["error"])
23522518
23542520 """
23552521 The lookup should succeed for an admin.
23562522 """
2357 channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
2523 channel = self.make_request(
2524 "GET",
2525 self.url1,
2526 access_token=self.admin_user_tok,
2527 )
23582528 self.assertEqual(200, channel.code, msg=channel.json_body)
23592529 self.assertEqual(self.other_user, channel.json_body["user_id"])
23602530 self.assertIn("devices", channel.json_body)
23612531
2362 channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
2532 channel = self.make_request(
2533 "GET",
2534 self.url2,
2535 access_token=self.admin_user_tok,
2536 )
23632537 self.assertEqual(200, channel.code, msg=channel.json_body)
23642538 self.assertEqual(self.other_user, channel.json_body["user_id"])
23652539 self.assertIn("devices", channel.json_body)
23702544 """
23712545 other_user_token = self.login("user", "pass")
23722546
2373 channel = self.make_request("GET", self.url1, access_token=other_user_token,)
2547 channel = self.make_request(
2548 "GET",
2549 self.url1,
2550 access_token=other_user_token,
2551 )
23742552 self.assertEqual(200, channel.code, msg=channel.json_body)
23752553 self.assertEqual(self.other_user, channel.json_body["user_id"])
23762554 self.assertIn("devices", channel.json_body)
23772555
2378 channel = self.make_request("GET", self.url2, access_token=other_user_token,)
2556 channel = self.make_request(
2557 "GET",
2558 self.url2,
2559 access_token=other_user_token,
2560 )
23792561 self.assertEqual(200, channel.code, msg=channel.json_body)
23802562 self.assertEqual(self.other_user, channel.json_body["user_id"])
23812563 self.assertIn("devices", channel.json_body)
7272
7373 # Mod the mod
7474 room_power_levels = self.helper.get_state(
75 self.room_id, "m.room.power_levels", tok=self.admin_access_token,
75 self.room_id,
76 "m.room.power_levels",
77 tok=self.admin_access_token,
7678 )
7779
7880 # Update existing power levels with mod at PL50
180180 )
181181
182182 def test_redact_event_as_moderator_ratelimit(self):
183 """Tests that the correct ratelimiting is applied to redactions
184 """
183 """Tests that the correct ratelimiting is applied to redactions"""
185184
186185 message_ids = []
187186 # as a regular user, send messages to redact
249249 mock_federation_client = Mock(spec=["backfill"])
250250
251251 self.hs = self.setup_test_homeserver(
252 config=config, federation_client=mock_federation_client,
252 config=config,
253 federation_client=mock_federation_client,
253254 )
254255 return self.hs
255256
259259 message_handler = self.hs.get_message_handler()
260260 event = self.get_success(
261261 message_handler.get_room_data(
262 self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
262 self.banned_user_id,
263 room_id,
264 "m.room.member",
265 self.banned_user_id,
263266 )
264267 )
265268 self.assertEqual(
291294 message_handler = self.hs.get_message_handler()
292295 event = self.get_success(
293296 message_handler.get_room_data(
294 self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
297 self.banned_user_id,
298 room_id,
299 "m.room.member",
300 self.banned_user_id,
295301 )
296302 )
297303 self.assertEqual(
149149 event_id = resp["event_id"]
150150
151151 channel = self.make_request(
152 "GET", "/events/" + event_id, access_token=self.token,
152 "GET",
153 "/events/" + event_id,
154 access_token=self.token,
153155 )
154156 self.assertEquals(channel.code, 200, msg=channel.result)
1414
1515 import time
1616 import urllib.parse
17 from typing import Any, Dict, Union
17 from typing import Any, Dict, List, Union
1818 from urllib.parse import urlencode
1919
2020 from mock import Mock
492492 self.assertEqual(channel.code, 200, channel.result)
493493
494494 # parse the form to check it has fields assumed elsewhere in this class
495 html = channel.result["body"].decode("utf-8")
495496 p = TestHtmlParser()
496 p.feed(channel.result["body"].decode("utf-8"))
497 p.feed(html)
497498 p.close()
498499
499 self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
500
501 self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
500 # there should be a link for each href
501 returned_idps = [] # type: List[str]
502 for link in p.links:
503 path, query = link.split("?", 1)
504 self.assertEqual(path, "pick_idp")
505 params = urllib.parse.parse_qs(query)
506 self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
507 returned_idps.append(params["idp"][0])
508
509 self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
502510
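The rewritten assertions above stop reading radio-button values and instead split each `pick_idp` link into path and query, then parse the query string. A hedged sketch of the same parsing logic, using made-up link values (the real hrefs come from `TestHtmlParser` and `TEST_CLIENT_REDIRECT_URL` in the test):

```python
import urllib.parse

# Hypothetical links like those the IdP picker page renders; values are
# illustrative only.
links = [
    "pick_idp?redirectUrl=https%3A%2F%2Fx%3Fabc%3Ddef&idp=cas",
    "pick_idp?redirectUrl=https%3A%2F%2Fx%3Fabc%3Ddef&idp=oidc",
]

returned_idps = []
for link in links:
    path, query = link.split("?", 1)
    assert path == "pick_idp"
    # parse_qs percent-decodes values and returns a list per key.
    params = urllib.parse.parse_qs(query)
    assert params["redirectUrl"] == ["https://x?abc=def"]
    returned_idps.append(params["idp"][0])

assert returned_idps == ["cas", "oidc"]
```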
503511 def test_multi_sso_redirect_to_cas(self):
504512 """If CAS is chosen, should redirect to the CAS server"""
602610 # matrix access token, mxid, and device id.
603611 login_token = params[2][1]
604612 chan = self.make_request(
605 "POST", "/login", content={"type": "m.login.token", "token": login_token},
613 "POST",
614 "/login",
615 content={"type": "m.login.token", "token": login_token},
606616 )
607617 self.assertEqual(chan.code, 200, chan.result)
608618 self.assertEqual(chan.json_body["user_id"], "@user1:test")
610620 def test_multi_sso_redirect_to_unknown(self):
611621 """An unknown IdP should cause a 400"""
612622 channel = self.make_request(
613 "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
623 "GET",
624 "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
614625 )
615626 self.assertEqual(channel.code, 400, channel.result)
616627
710721 mocked_http_client.get_raw.side_effect = get_raw
711722
712723 self.hs = self.setup_test_homeserver(
713 config=config, proxied_http_client=mocked_http_client,
724 config=config,
725 proxied_http_client=mocked_http_client,
714726 )
715727
716728 return self.hs
12351247 # looks ok.
12361248 username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
12371249 self.assertIn(
1238 session_id, username_mapping_sessions, "session id not found in map",
1250 session_id,
1251 username_mapping_sessions,
1252 "session id not found in map",
12391253 )
12401254 session = username_mapping_sessions[session_id]
12411255 self.assertEqual(session.remote_user_id, "tester")
12901304 # finally, submit the matrix login token to the login API, which gives us our
12911305 # matrix access token, mxid, and device id.
12921306 chan = self.make_request(
1293 "POST", "/login", content={"type": "m.login.token", "token": login_token},
1307 "POST",
1308 "/login",
1309 content={"type": "m.login.token", "token": login_token},
12941310 )
12951311 self.assertEqual(chan.code, 200, chan.result)
12961312 self.assertEqual(chan.json_body["user_id"], "@bobby:test")
1313 # limitations under the License.
1414
1515 """Tests REST events for /profile paths."""
16 import json
17
18 from mock import Mock
19
20 from twisted.internet import defer
21
22 import synapse.types
23 from synapse.api.errors import AuthError, SynapseError
2416 from synapse.rest import admin
2517 from synapse.rest.client.v1 import login, profile, room
2618
2719 from tests import unittest
28
29 from ....utils import MockHttpResource, setup_test_homeserver
30
31 myid = "@1234ABCD:test"
32 PATH_PREFIX = "/_matrix/client/r0"
33
34
35 class MockHandlerProfileTestCase(unittest.TestCase):
36 """ Tests rest layer of profile management.
37
38 Todo: move these into ProfileTestCase
39 """
40
41 @defer.inlineCallbacks
42 def setUp(self):
43 self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
44 self.mock_handler = Mock(
45 spec=[
46 "get_displayname",
47 "set_displayname",
48 "get_avatar_url",
49 "set_avatar_url",
50 "check_profile_query_allowed",
51 ]
52 )
53
54 self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
55 self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
56 self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
57 self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
58 self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
59 Mock()
60 )
61
62 hs = yield setup_test_homeserver(
63 self.addCleanup,
64 "test",
65 federation_http_client=None,
66 resource_for_client=self.mock_resource,
67 federation=Mock(),
68 federation_client=Mock(),
69 profile_handler=self.mock_handler,
70 )
71
72 async def _get_user_by_req(request=None, allow_guest=False):
73 return synapse.types.create_requester(myid)
74
75 hs.get_auth().get_user_by_req = _get_user_by_req
76
77 profile.register_servlets(hs, self.mock_resource)
78
79 @defer.inlineCallbacks
80 def test_get_my_name(self):
81 mocked_get = self.mock_handler.get_displayname
82 mocked_get.return_value = defer.succeed("Frank")
83
84 (code, response) = yield self.mock_resource.trigger(
85 "GET", "/profile/%s/displayname" % (myid), None
86 )
87
88 self.assertEquals(200, code)
89 self.assertEquals({"displayname": "Frank"}, response)
90 self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
91
92 @defer.inlineCallbacks
93 def test_set_my_name(self):
94 mocked_set = self.mock_handler.set_displayname
95 mocked_set.return_value = defer.succeed(())
96
97 (code, response) = yield self.mock_resource.trigger(
98 "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
99 )
100
101 self.assertEquals(200, code)
102 self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
103 self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
104 self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
105
106 @defer.inlineCallbacks
107 def test_set_my_name_noauth(self):
108 mocked_set = self.mock_handler.set_displayname
109 mocked_set.side_effect = AuthError(400, "message")
110
111 (code, response) = yield self.mock_resource.trigger(
112 "PUT",
113 "/profile/%s/displayname" % ("@4567:test"),
114 b'{"displayname": "Frank Jr."}',
115 )
116
117 self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
118
119 @defer.inlineCallbacks
120 def test_get_other_name(self):
121 mocked_get = self.mock_handler.get_displayname
122 mocked_get.return_value = defer.succeed("Bob")
123
124 (code, response) = yield self.mock_resource.trigger(
125 "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
126 )
127
128 self.assertEquals(200, code)
129 self.assertEquals({"displayname": "Bob"}, response)
130
131 @defer.inlineCallbacks
132 def test_set_other_name(self):
133 mocked_set = self.mock_handler.set_displayname
134 mocked_set.side_effect = SynapseError(400, "message")
135
136 (code, response) = yield self.mock_resource.trigger(
137 "PUT",
138 "/profile/%s/displayname" % ("@opaque:elsewhere"),
139 b'{"displayname":"bob"}',
140 )
141
142 self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
143
144 @defer.inlineCallbacks
145 def test_get_my_avatar(self):
146 mocked_get = self.mock_handler.get_avatar_url
147 mocked_get.return_value = defer.succeed("http://my.server/me.png")
148
149 (code, response) = yield self.mock_resource.trigger(
150 "GET", "/profile/%s/avatar_url" % (myid), None
151 )
152
153 self.assertEquals(200, code)
154 self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
155 self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
156
157 @defer.inlineCallbacks
158 def test_set_my_avatar(self):
159 mocked_set = self.mock_handler.set_avatar_url
160 mocked_set.return_value = defer.succeed(())
161
162 (code, response) = yield self.mock_resource.trigger(
163 "PUT",
164 "/profile/%s/avatar_url" % (myid),
165 b'{"avatar_url": "http://my.server/pic.gif"}',
166 )
167
168 self.assertEquals(200, code)
169 self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
170 self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
171 self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
17220
17321
17422 class ProfileTestCase(unittest.HomeserverTestCase):
18634 def prepare(self, reactor, clock, hs):
18735 self.owner = self.register_user("owner", "pass")
18836 self.owner_tok = self.login("owner", "pass")
37 self.other = self.register_user("other", "pass", displayname="Bob")
38
39 def test_get_displayname(self):
40 res = self._get_displayname()
41 self.assertEqual(res, "owner")
18942
19043 def test_set_displayname(self):
19144 channel = self.make_request(
19245 "PUT",
19346 "/profile/%s/displayname" % (self.owner,),
194 content=json.dumps({"displayname": "test"}),
195 access_token=self.owner_tok,
196 )
197 self.assertEqual(channel.code, 200, channel.result)
198
199 res = self.get_displayname()
47 content={"displayname": "test"},
48 access_token=self.owner_tok,
49 )
50 self.assertEqual(channel.code, 200, channel.result)
51
52 res = self._get_displayname()
20053 self.assertEqual(res, "test")
54
55 def test_set_displayname_noauth(self):
56 channel = self.make_request(
57 "PUT",
58 "/profile/%s/displayname" % (self.owner,),
59 content={"displayname": "test"},
60 )
61 self.assertEqual(channel.code, 401, channel.result)
20162
20263 def test_set_displayname_too_long(self):
20364 """Attempts to set a stupid displayname should get a 400"""
20465 channel = self.make_request(
20566 "PUT",
20667 "/profile/%s/displayname" % (self.owner,),
207 content=json.dumps({"displayname": "test" * 100}),
208 access_token=self.owner_tok,
209 )
210 self.assertEqual(channel.code, 400, channel.result)
211
212 res = self.get_displayname()
68 content={"displayname": "test" * 100},
69 access_token=self.owner_tok,
70 )
71 self.assertEqual(channel.code, 400, channel.result)
72
73 res = self._get_displayname()
21374 self.assertEqual(res, "owner")
21475
215 def get_displayname(self):
216 channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
76 def test_get_displayname_other(self):
77 res = self._get_displayname(self.other)
78 self.assertEquals(res, "Bob")
79
80 def test_set_displayname_other(self):
81 channel = self.make_request(
82 "PUT",
83 "/profile/%s/displayname" % (self.other,),
84 content={"displayname": "test"},
85 access_token=self.owner_tok,
86 )
87 self.assertEqual(channel.code, 400, channel.result)
88
89 def test_get_avatar_url(self):
90 res = self._get_avatar_url()
91 self.assertIsNone(res)
92
93 def test_set_avatar_url(self):
94 channel = self.make_request(
95 "PUT",
96 "/profile/%s/avatar_url" % (self.owner,),
97 content={"avatar_url": "http://my.server/pic.gif"},
98 access_token=self.owner_tok,
99 )
100 self.assertEqual(channel.code, 200, channel.result)
101
102 res = self._get_avatar_url()
103 self.assertEqual(res, "http://my.server/pic.gif")
104
105 def test_set_avatar_url_noauth(self):
106 channel = self.make_request(
107 "PUT",
108 "/profile/%s/avatar_url" % (self.owner,),
109 content={"avatar_url": "http://my.server/pic.gif"},
110 )
111 self.assertEqual(channel.code, 401, channel.result)
112
113 def test_set_avatar_url_too_long(self):
114 """Attempts to set a stupid avatar_url should get a 400"""
115 channel = self.make_request(
116 "PUT",
117 "/profile/%s/avatar_url" % (self.owner,),
118 content={"avatar_url": "http://my.server/pic.gif" * 100},
119 access_token=self.owner_tok,
120 )
121 self.assertEqual(channel.code, 400, channel.result)
122
123 res = self._get_avatar_url()
124 self.assertIsNone(res)
125
126 def test_get_avatar_url_other(self):
127 res = self._get_avatar_url(self.other)
128 self.assertIsNone(res)
129
130 def test_set_avatar_url_other(self):
131 channel = self.make_request(
132 "PUT",
133 "/profile/%s/avatar_url" % (self.other,),
134 content={"avatar_url": "http://my.server/pic.gif"},
135 access_token=self.owner_tok,
136 )
137 self.assertEqual(channel.code, 400, channel.result)
138
139 def _get_displayname(self, name=None):
140 channel = self.make_request(
141 "GET", "/profile/%s/displayname" % (name or self.owner,)
142 )
217143 self.assertEqual(channel.code, 200, channel.result)
218144 return channel.json_body["displayname"]
145
146 def _get_avatar_url(self, name=None):
147 channel = self.make_request(
148 "GET", "/profile/%s/avatar_url" % (name or self.owner,)
149 )
150 self.assertEqual(channel.code, 200, channel.result)
151 return channel.json_body.get("avatar_url")
219152
220153
221154 class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
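The rewritten profile tests above pass `content=` as a plain dict instead of a `json.dumps(...)` string, relying on the request helper to serialize dicts itself. A sketch of that kind of dispatch, under the assumption (not Synapse's actual code) that the helper accepts either form:

```python
import json

def encode_body(content):
    # Hypothetical sketch: accept either a pre-encoded JSON string/bytes
    # or a plain dict, and serialize the dict ourselves.
    if isinstance(content, dict):
        return json.dumps(content).encode("utf8")
    if isinstance(content, str):
        return content.encode("utf8")
    return content

# Both call styles produce the same request body.
assert encode_body({"displayname": "test"}) == b'{"displayname": "test"}'
assert encode_body('{"displayname": "test"}') == b'{"displayname": "test"}'
```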
4545 def make_homeserver(self, reactor, clock):
4646
4747 self.hs = self.setup_test_homeserver(
48 "red", federation_http_client=None, federation_client=Mock(),
48 "red",
49 federation_http_client=None,
50 federation_client=Mock(),
4951 )
5052
5153 self.hs.get_federation_handler = Mock()
14791481 results = channel.json_body["search_categories"]["room_events"]["results"]
14801482
14811483 self.assertEqual(
1482 len(results), 2, [result["result"]["content"] for result in results],
1484 len(results),
1485 2,
1486 [result["result"]["content"] for result in results],
14831487 )
14841488 self.assertEqual(
14851489 results[0]["result"]["content"]["body"],
15141518 results = channel.json_body["search_categories"]["room_events"]["results"]
15151519
15161520 self.assertEqual(
1517 len(results), 4, [result["result"]["content"] for result in results],
1521 len(results),
1522 4,
1523 [result["result"]["content"] for result in results],
15181524 )
15191525 self.assertEqual(
15201526 results[0]["result"]["content"]["body"],
15611567 results = channel.json_body["search_categories"]["room_events"]["results"]
15621568
15631569 self.assertEqual(
1564 len(results), 1, [result["result"]["content"] for result in results],
1570 len(results),
1571 1,
1572 [result["result"]["content"] for result in results],
15651573 )
15661574 self.assertEqual(
15671575 results[0]["result"]["content"]["body"],
1717
1818 from mock import Mock
1919
20 from twisted.internet import defer
21
2220 from synapse.rest.client.v1 import room
2321 from synapse.types import UserID
2422
3836 def make_homeserver(self, reactor, clock):
3937
4038 hs = self.setup_test_homeserver(
41 "red", federation_http_client=None, federation_client=Mock(),
39 "red",
40 federation_http_client=None,
41 federation_client=Mock(),
4242 )
4343
4444 self.event_source = hs.get_event_sources().sources["typing"]
5858 return None
5959
6060 hs.get_datastore().insert_client_ip = _insert_client_ip
61
62 def get_room_members(room_id):
63 if room_id == self.room_id:
64 return defer.succeed([self.user])
65 else:
66 return defer.succeed([])
67
68 @defer.inlineCallbacks
69 def fetch_room_distributions_into(
70 room_id, localusers=None, remotedomains=None, ignore_user=None
71 ):
72 members = yield get_room_members(room_id)
73 for member in members:
74 if ignore_user is not None and member == ignore_user:
75 continue
76
77 if hs.is_mine(member):
78 if localusers is not None:
79 localusers.add(member)
80 else:
81 if remotedomains is not None:
82 remotedomains.add(member.domain)
83
84 hs.get_room_member_handler().fetch_room_distributions_into = (
85 fetch_room_distributions_into
86 )
8761
8862 return hs
8963
165165 json.dumps(data).encode("utf8"),
166166 )
167167
168 assert int(channel.result["code"]) == expect_code, (
169 "Expected: %d, got: %d, resp: %r"
170 % (expect_code, int(channel.result["code"]), channel.result["body"])
168 assert (
169 int(channel.result["code"]) == expect_code
170 ), "Expected: %d, got: %d, resp: %r" % (
171 expect_code,
172 int(channel.result["code"]),
173 channel.result["body"],
171174 )
172175
173176 self.auth_user_id = temp_id
200203 json.dumps(content).encode("utf8"),
201204 )
202205
203 assert int(channel.result["code"]) == expect_code, (
204 "Expected: %d, got: %d, resp: %r"
205 % (expect_code, int(channel.result["code"]), channel.result["body"])
206 assert (
207 int(channel.result["code"]) == expect_code
208 ), "Expected: %d, got: %d, resp: %r" % (
209 expect_code,
210 int(channel.result["code"]),
211 channel.result["body"],
206212 )
207213
208214 return channel.json_body
250256
251257 channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
252258
253 assert int(channel.result["code"]) == expect_code, (
254 "Expected: %d, got: %d, resp: %r"
255 % (expect_code, int(channel.result["code"]), channel.result["body"])
259 assert (
260 int(channel.result["code"]) == expect_code
261 ), "Expected: %d, got: %d, resp: %r" % (
262 expect_code,
263 int(channel.result["code"]),
264 channel.result["body"],
256265 )
257266
258267 return channel.json_body
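The assert reformattings above come from Black's "magic trailing comma": a trailing comma inside a call or tuple makes the formatter explode it to one element per line. A minimal sketch showing that the two layouts are semantically identical (`check_code` is a hypothetical stand-in for the test helper, not Synapse code):

```python
def check_code(code: int, expect_code: int) -> str:
    # Old single-line style (pre magic-trailing-comma):
    msg_old = "Expected: %d, got: %d" % (expect_code, code)
    # New exploded style, as Black emits when a trailing comma is kept:
    msg_new = "Expected: %d, got: %d" % (
        expect_code,
        code,
    )
    # Only the formatting differs; the built message is the same.
    assert msg_old == msg_new
    return msg_new
```

This is why the diff is so large while changing no behaviour: every call that already carried a trailing comma was re-wrapped.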
446455 return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
447456
448457 def complete_oidc_auth(
449 self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
458 self,
459 oauth_uri: str,
460 cookies: Mapping[str, str],
461 user_info_dict: JsonDict,
450462 ) -> FakeChannel:
451463 """Mock out an OIDC authentication flow
452464
490502 (expected_uri, resp_obj) = expected_requests.pop(0)
491503 assert uri == expected_uri
492504 resp = FakeResponse(
493 code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
505 code=200,
506 phrase=b"OK",
507 body=json.dumps(resp_obj).encode("utf-8"),
494508 )
495509 return resp
496510
7474 self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
7575
7676 def test_basic_password_reset(self):
77 """Test basic password reset flow
78 """
77 """Test basic password reset flow"""
7978 old_password = "monkey"
8079 new_password = "kangeroo"
8180
113112
114113 @override_config({"rc_3pid_validation": {"burst_count": 3}})
115114 def test_ratelimit_by_email(self):
116 """Test that we ratelimit /requestToken for the same email.
117 """
115 """Test that we ratelimit /requestToken for the same email."""
118116 old_password = "monkey"
119117 new_password = "kangeroo"
120118
202200 self.attempt_wrong_password_login("kermit", old_password)
203201
204202 def test_cant_reset_password_without_clicking_link(self):
205 """Test that we do actually need to click the link in the email
206 """
203 """Test that we do actually need to click the link in the email"""
207204 old_password = "monkey"
208205 new_password = "kangeroo"
209206
298295
299296 if channel.code != 200:
300297 raise HttpResponseException(
301 channel.code, channel.result["reason"], channel.result["body"],
298 channel.code,
299 channel.result["reason"],
300 channel.result["body"],
302301 )
303302
304303 return channel.json_body["sid"]
565564
566565 @override_config({"rc_3pid_validation": {"burst_count": 3}})
567566 def test_ratelimit_by_ip(self):
568 """Tests that adding emails is ratelimited by IP
569 """
567 """Tests that adding emails is ratelimited by IP"""
570568
571569 # We expect to be able to set three emails before getting ratelimited.
572570 self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
579577 self.assertEqual(cm.exception.code, 429)
580578
581579 def test_add_email_if_disabled(self):
582 """Test adding email to profile when doing so is disallowed
583 """
580 """Test adding email to profile when doing so is disallowed"""
584581 self.hs.config.enable_3pid_changes = False
585582
586583 client_secret = "foobar"
610607
611608 # Get user
612609 channel = self.make_request(
613 "GET", self.url_3pid, access_token=self.user_id_tok,
610 "GET",
611 self.url_3pid,
612 access_token=self.user_id_tok,
614613 )
615614
616615 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
617616 self.assertFalse(channel.json_body["threepids"])
618617
619618 def test_delete_email(self):
620 """Test deleting an email from profile
621 """
619 """Test deleting an email from profile"""
622620 # Add a threepid
623621 self.get_success(
624622 self.store.user_add_threepid(
640638
641639 # Get user
642640 channel = self.make_request(
643 "GET", self.url_3pid, access_token=self.user_id_tok,
641 "GET",
642 self.url_3pid,
643 access_token=self.user_id_tok,
644644 )
645645
646646 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
647647 self.assertFalse(channel.json_body["threepids"])
648648
649649 def test_delete_email_if_disabled(self):
650 """Test deleting an email from profile when disallowed
651 """
650 """Test deleting an email from profile when disallowed"""
652651 self.hs.config.enable_3pid_changes = False
653652
654653 # Add a threepid
674673
675674 # Get user
676675 channel = self.make_request(
677 "GET", self.url_3pid, access_token=self.user_id_tok,
676 "GET",
677 self.url_3pid,
678 access_token=self.user_id_tok,
678679 )
679680
680681 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
682683 self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
683684
684685 def test_cant_add_email_without_clicking_link(self):
685 """Test that we do actually need to click the link in the email
686 """
686 """Test that we do actually need to click the link in the email"""
687687 client_secret = "foobar"
688688 session_id = self._request_token(self.email, client_secret)
689689
709709
710710 # Get user
711711 channel = self.make_request(
712 "GET", self.url_3pid, access_token=self.user_id_tok,
712 "GET",
713 self.url_3pid,
714 access_token=self.user_id_tok,
713715 )
714716
715717 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
742744
743745 # Get user
744746 channel = self.make_request(
745 "GET", self.url_3pid, access_token=self.user_id_tok,
747 "GET",
748 self.url_3pid,
749 access_token=self.user_id_tok,
746750 )
747751
748752 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
787791
788792 # Ensure not providing a next_link parameter still works
789793 self._request_token(
790 "something@example.com", "some_secret", next_link=None, expect_code=200,
794 "something@example.com",
795 "some_secret",
796 next_link=None,
797 expect_code=200,
791798 )
792799
793800 self._request_token(
845852 if next_link:
846853 body["next_link"] = next_link
847854
848 channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
855 channel = self.make_request(
856 "POST",
857 b"account/3pid/email/requestToken",
858 body,
859 )
849860
850861 if channel.code != expect_code:
851862 raise HttpResponseException(
852 channel.code, channel.result["reason"], channel.result["body"],
863 channel.code,
864 channel.result["reason"],
865 channel.result["body"],
853866 )
854867
855868 return channel.json_body.get("sid")
856869
857870 def _request_token_invalid_email(
858 self, email, expected_errcode, expected_error, client_secret="foobar",
871 self,
872 email,
873 expected_errcode,
874 expected_error,
875 client_secret="foobar",
859876 ):
860877 channel = self.make_request(
861878 "POST",
894911 return match.group(0)
895912
896913 def _add_email(self, request_email, expected_email):
897 """Test adding an email to profile
898 """
914 """Test adding an email to profile"""
899915 previous_email_attempts = len(self.email_attempts)
900916
901917 client_secret = "foobar"
925941
926942 # Get user
927943 channel = self.make_request(
928 "GET", self.url_3pid, access_token=self.user_id_tok,
944 "GET",
945 self.url_3pid,
946 access_token=self.user_id_tok,
929947 )
930948
931949 self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
101101 """Ensure that fallback auth via a captcha works."""
102102 # Returns a 401 as per the spec
103103 channel = self.register(
104 401, {"username": "user", "type": "m.login.password", "password": "bar"},
104 401,
105 {"username": "user", "type": "m.login.password", "password": "bar"},
105106 )
106107
107108 # Grab the session
190191 ) -> FakeChannel:
191192 """Delete an individual device."""
192193 channel = self.make_request(
193 "DELETE", "devices/" + device, body, access_token=access_token,
194 "DELETE",
195 "devices/" + device,
196 body,
197 access_token=access_token,
194198 )
195199
196200 # Ensure the response is sane.
203207 # Note that this uses the delete_devices endpoint so that we can modify
204208 # the payload half-way through some tests.
205209 channel = self.make_request(
206 "POST", "delete_devices", body, access_token=self.user_tok,
210 "POST",
211 "delete_devices",
212 body,
213 access_token=self.user_tok,
207214 )
208215
209216 # Ensure the response is sane.
335342 },
336343 )
337344
338 @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}})
345 @unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
339346 def test_can_reuse_session(self):
340347 """
341348 The session can be reused if configured.
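The `session_timeout` change above (`5 * 1000` becoming `"5s"`) reflects Synapse config options accepting human-readable duration strings alongside bare milliseconds. A rough sketch of such parsing, not Synapse's actual parser (it supports more suffixes):

```python
def parse_duration_ms(value) -> int:
    # Bare integers are already milliseconds.
    if isinstance(value, int):
        return value
    # Strings may carry a unit suffix: seconds, minutes, hours, days.
    units = {"s": 1000, "m": 60_000, "h": 3_600_000, "d": 86_400_000}
    suffix = value[-1]
    if suffix in units:
        return int(value[:-1]) * units[suffix]
    # Fall back to a plain string of milliseconds.
    return int(value)
```

Under this scheme the old `5 * 1000` and the new `"5s"` configure the same 5000 ms timeout.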
416423
417424 # and now the delete request should succeed.
418425 self.delete_device(
419 self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
426 self.user_tok,
427 self.device_id,
428 200,
429 body={"auth": {"session": session_id}},
420430 )
421431
422432 @skip_unless(HAS_OIDC, "requires OIDC")
442452 @skip_unless(HAS_OIDC, "requires OIDC")
443453 @override_config({"oidc_config": TEST_OIDC_CONFIG})
444454 def test_offers_both_flows_for_upgraded_user(self):
445 """A user that had a password and then logged in with SSO should get both flows
446 """
455 """A user that had a password and then logged in with SSO should get both flows"""
447456 login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
448457 self.assertEqual(login_resp["user_id"], self.user)
449458
458467 @skip_unless(HAS_OIDC, "requires OIDC")
459468 @override_config({"oidc_config": TEST_OIDC_CONFIG})
460469 def test_ui_auth_fails_for_incorrect_sso_user(self):
461 """If the user tries to authenticate with the wrong SSO user, they get an error
462 """
470 """If the user tries to authenticate with the wrong SSO user, they get an error"""
463471 # log the user in
464472 login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
465473 self.assertEqual(login_resp["user_id"], self.user)
9090
9191 self.assertEqual(channel.code, 400, channel.result)
9292 self.assertEqual(
93 channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
93 channel.json_body["errcode"],
94 Codes.PASSWORD_TOO_SHORT,
95 channel.result,
9496 )
9597
9698 def test_password_no_digit(self):
99101
100102 self.assertEqual(channel.code, 400, channel.result)
101103 self.assertEqual(
102 channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
104 channel.json_body["errcode"],
105 Codes.PASSWORD_NO_DIGIT,
106 channel.result,
103107 )
104108
105109 def test_password_no_symbol(self):
108112
109113 self.assertEqual(channel.code, 400, channel.result)
110114 self.assertEqual(
111 channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
115 channel.json_body["errcode"],
116 Codes.PASSWORD_NO_SYMBOL,
117 channel.result,
112118 )
113119
114120 def test_password_no_uppercase(self):
117123
118124 self.assertEqual(channel.code, 400, channel.result)
119125 self.assertEqual(
120 channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
126 channel.json_body["errcode"],
127 Codes.PASSWORD_NO_UPPERCASE,
128 channel.result,
121129 )
122130
123131 def test_password_no_lowercase(self):
126134
127135 self.assertEqual(channel.code, 400, channel.result)
128136 self.assertEqual(
129 channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
137 channel.json_body["errcode"],
138 Codes.PASSWORD_NO_LOWERCASE,
139 channel.result,
130140 )
131141
132142 def test_password_compliant(self):
8282 )
8383
8484 def test_deny_membership(self):
85 """Test that we deny relations on membership events
86 """
85 """Test that we deny relations on membership events"""
8786 channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
8887 self.assertEquals(400, channel.code, channel.json_body)
8988
9089 def test_deny_double_react(self):
91 """Test that we deny double reactions to the same event
92 """
90 """Test that we deny double reactions to the same event"""
9391 channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
9492 self.assertEquals(200, channel.code, channel.json_body)
9593
9795 self.assertEquals(400, channel.code, channel.json_body)
9896
9997 def test_basic_paginate_relations(self):
100 """Tests that calling the pagination API correctly returns the latest relations.
101 """
98 """Tests that calling the pagination API correctly returns the latest relations."""
10299 channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
103100 self.assertEquals(200, channel.code, channel.json_body)
104101
173170 self.assertEquals(found_event_ids, expected_event_ids)
174171
175172 def test_aggregation_pagination_groups(self):
176 """Test that we can paginate annotation groups correctly.
177 """
173 """Test that we can paginate annotation groups correctly."""
178174
179175 # We need to create ten separate users to send each reaction.
180176 access_tokens = [self.user_token, self.user2_token]
239235 self.assertEquals(sent_groups, found_groups)
240236
241237 def test_aggregation_pagination_within_group(self):
242 """Test that we can paginate within an annotation group.
243 """
238 """Test that we can paginate within an annotation group."""
244239
245240 # We need to create ten separate users to send each reaction.
246241 access_tokens = [self.user_token, self.user2_token]
310305 self.assertEquals(found_event_ids, expected_event_ids)
311306
312307 def test_aggregation(self):
313 """Test that annotations get correctly aggregated.
314 """
308 """Test that annotations get correctly aggregated."""
315309
316310 channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
317311 self.assertEquals(200, channel.code, channel.json_body)
343337 )
344338
345339 def test_aggregation_redactions(self):
346 """Test that annotations get correctly aggregated after a redaction.
347 """
340 """Test that annotations get correctly aggregated after a redaction."""
348341
349342 channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
350343 self.assertEquals(200, channel.code, channel.json_body)
378371 )
379372
380373 def test_aggregation_must_be_annotation(self):
381 """Test that aggregations must be annotations.
382 """
374 """Test that aggregations must be annotations."""
383375
384376 channel = self.make_request(
385377 "GET",
436428 )
437429
438430 def test_edit(self):
439 """Test that a simple edit works.
440 """
431 """Test that a simple edit works."""
441432
442433 new_body = {"msgtype": "m.text", "body": "I've been edited!"}
443434 channel = self._send_relation(
387387
388388 # Check that room name changes increase the unread counter.
389389 self.helper.send_state(
390 self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
390 self.room_id,
391 "m.room.name",
392 {"name": "my super room"},
393 tok=self.tok2,
391394 )
392395 self._check_unread_count(1)
393396
394397 # Check that room topic changes increase the unread counter.
395398 self.helper.send_state(
396 self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
399 self.room_id,
400 "m.room.topic",
401 {"topic": "welcome!!!"},
402 tok=self.tok2,
397403 )
398404 self._check_unread_count(2)
399405
403409
404410 # Check that custom events with a body increase the unread counter.
405411 self.helper.send_event(
406 self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
412 self.room_id,
413 "org.matrix.custom_type",
414 {"body": "hello"},
415 tok=self.tok2,
407416 )
408417 self._check_unread_count(4)
409418
442451 """Syncs and compares the unread count with the expected value."""
443452
444453 channel = self.make_request(
445 "GET", self.url % self.next_batch, access_token=self.tok,
454 "GET",
455 self.url % self.next_batch,
456 access_token=self.tok,
446457 )
447458
448459 self.assertEqual(channel.code, 200, channel.json_body)
449460
450461 room_entry = channel.json_body["rooms"]["join"][self.room_id]
451462 self.assertEqual(
452 room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
463 room_entry["org.matrix.msc2654.unread_count"],
464 expected_count,
465 room_entry,
453466 )
454467
455468 # Store the next batch for the next request.
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 from typing import Optional
15
16 from synapse.config.server import DEFAULT_ROOM_VERSION
17 from synapse.rest import admin
18 from synapse.rest.client.v1 import login, room
19 from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
20
21 from tests import unittest
22 from tests.server import FakeChannel
23
24
25 class UpgradeRoomTest(unittest.HomeserverTestCase):
26 servlets = [
27 admin.register_servlets,
28 login.register_servlets,
29 room.register_servlets,
30 room_upgrade_rest_servlet.register_servlets,
31 ]
32
33 def prepare(self, reactor, clock, hs):
34 self.store = hs.get_datastore()
35 self.handler = hs.get_user_directory_handler()
36
37 self.creator = self.register_user("creator", "pass")
38 self.creator_token = self.login(self.creator, "pass")
39
40 self.other = self.register_user("user", "pass")
41 self.other_token = self.login(self.other, "pass")
42
43 self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token)
44 self.helper.join(self.room_id, self.other, tok=self.other_token)
45
46 def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel:
47 # We never want a cached response.
48 self.reactor.advance(5 * 60 + 1)
49
50 return self.make_request(
51 "POST",
52 "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id,
53 # This will upgrade a room to the same version, but that's fine.
54 content={"new_version": DEFAULT_ROOM_VERSION},
55 access_token=token or self.creator_token,
56 )
57
58 def test_upgrade(self):
59 """
60 Upgrading a room should work fine.
61 """
62 channel = self._upgrade_room()
63 self.assertEquals(200, channel.code, channel.result)
64 self.assertIn("replacement_room", channel.json_body)
65
66 def test_not_in_room(self):
67 """
68 Upgrading a room should fail if the requesting user is not in the room.
69 """
70 # The user isn't in the room.
71 roomless = self.register_user("roomless", "pass")
72 roomless_token = self.login(roomless, "pass")
73
74 channel = self._upgrade_room(roomless_token)
75 self.assertEquals(403, channel.code, channel.result)
76
77 def test_power_levels(self):
78 """
79 Another user can upgrade the room if their power level is increased.
80 """
81 # The other user doesn't have the proper power level.
82 channel = self._upgrade_room(self.other_token)
83 self.assertEquals(403, channel.code, channel.result)
84
85 # Increase the power levels so that this user can upgrade.
86 power_levels = self.helper.get_state(
87 self.room_id,
88 "m.room.power_levels",
89 tok=self.creator_token,
90 )
91 power_levels["users"][self.other] = 100
92 self.helper.send_state(
93 self.room_id,
94 "m.room.power_levels",
95 body=power_levels,
96 tok=self.creator_token,
97 )
98
99 # The upgrade should succeed!
100 channel = self._upgrade_room(self.other_token)
101 self.assertEquals(200, channel.code, channel.result)
102
103 def test_power_levels_user_default(self):
104 """
105 Another user can upgrade the room if the default power level for users is increased.
106 """
107 # The other user doesn't have the proper power level.
108 channel = self._upgrade_room(self.other_token)
109 self.assertEquals(403, channel.code, channel.result)
110
111 # Increase the power levels so that this user can upgrade.
112 power_levels = self.helper.get_state(
113 self.room_id,
114 "m.room.power_levels",
115 tok=self.creator_token,
116 )
117 power_levels["users_default"] = 100
118 self.helper.send_state(
119 self.room_id,
120 "m.room.power_levels",
121 body=power_levels,
122 tok=self.creator_token,
123 )
124
125 # The upgrade should succeed!
126 channel = self._upgrade_room(self.other_token)
127 self.assertEquals(200, channel.code, channel.result)
128
129 def test_power_levels_tombstone(self):
130 """
131 Another user can upgrade the room if they can send the tombstone event.
132 """
133 # The other user doesn't have the proper power level.
134 channel = self._upgrade_room(self.other_token)
135 self.assertEquals(403, channel.code, channel.result)
136
137 # Increase the power levels so that this user can upgrade.
138 power_levels = self.helper.get_state(
139 self.room_id,
140 "m.room.power_levels",
141 tok=self.creator_token,
142 )
143 power_levels["events"]["m.room.tombstone"] = 0
144 self.helper.send_state(
145 self.room_id,
146 "m.room.power_levels",
147 body=power_levels,
148 tok=self.creator_token,
149 )
150
151 # The upgrade should succeed!
152 channel = self._upgrade_room(self.other_token)
153 self.assertEquals(200, channel.code, channel.result)
154
155 power_levels = self.helper.get_state(
156 self.room_id,
157 "m.room.power_levels",
158 tok=self.creator_token,
159 )
160 self.assertNotIn(self.other, power_levels["users"])
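The power-level tests above exercise the standard Matrix rule for whether a sender may emit a state event such as `m.room.tombstone`: the sender's level comes from `users` (falling back to `users_default`), and the requirement from `events` (falling back to `state_default`). A rough sketch of that check, not Synapse's actual implementation:

```python
def can_send_state(power_levels: dict, user: str, event_type: str) -> bool:
    # Sender level: explicit entry in "users", else the room default.
    user_level = power_levels.get("users", {}).get(
        user, power_levels.get("users_default", 0)
    )
    # Required level: per-event override, else the state default (50).
    required = power_levels.get("events", {}).get(
        event_type, power_levels.get("state_default", 50)
    )
    return user_level >= required
```

This mirrors the three test paths: raising the user's entry in `users`, raising `users_default`, or lowering the `m.room.tombstone` entry in `events` each lets the upgrade succeed.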
179179 async def post_json(destination, path, data):
180180 self.assertEqual(destination, self.hs.hostname)
181181 self.assertEqual(
182 path, "/_matrix/key/v2/query",
182 path,
183 "/_matrix/key/v2/query",
183184 )
184185
185186 channel = FakeChannel(self.site, self.reactor)
187188 req.content = BytesIO(encode_canonical_json(data))
188189
189190 req.requestReceived(
190 b"POST", path.encode("utf-8"), b"1.1",
191 b"POST",
192 path.encode("utf-8"),
193 b"1.1",
191194 )
192195 channel.await_result()
193196 self.assertEqual(channel.code, 200)
2929 from twisted.internet.defer import Deferred
3030
3131 from synapse.logging.context import make_deferred_yieldable
32 from synapse.rest import admin
33 from synapse.rest.client.v1 import login
3234 from synapse.rest.media.v1._base import FileInfo
3335 from synapse.rest.media.v1.filepath import MediaFilePaths
3436 from synapse.rest.media.v1.media_storage import MediaStorage
3638
3739 from tests import unittest
3840 from tests.server import FakeSite, make_request
41 from tests.utils import default_config
3942
4043
4144 class MediaStorageTests(unittest.HomeserverTestCase):
163166 ),
164167 ),
165168 # an empty file
166 (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
169 (
170 _TestImage(
171 b"",
172 b"image/gif",
173 b".gif",
174 None,
175 None,
176 False,
177 ),
178 ),
167179 ],
168180 )
169181 class MediaRepoTests(unittest.HomeserverTestCase):
397409 headers.getRawHeaders(b"X-Robots-Tag"),
398410 [b"noindex, nofollow, noarchive, noimageindex"],
399411 )
412
413
414 class TestSpamChecker:
415 """A spam checker module that rejects all media that includes the bytes
416 `evil`.
417 """
418
419 def __init__(self, config, api):
420 self.config = config
421 self.api = api
422
423 def parse_config(config):
424 return config
425
426 async def check_event_for_spam(self, foo):
427 return False # allow all events
428
429 async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
430 return True # allow all invites
431
432 async def user_may_create_room(self, userid):
433 return True # allow all room creations
434
435 async def user_may_create_room_alias(self, userid, room_alias):
436 return True # allow all room aliases
437
438 async def user_may_publish_room(self, userid, room_id):
439 return True # allow publishing of all rooms
440
441 async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
442 buf = BytesIO()
443 await file_wrapper.write_chunks_to(buf.write)
444
445 return b"evil" in buf.getvalue()
446
447
448 class SpamCheckerTestCase(unittest.HomeserverTestCase):
449 servlets = [
450 login.register_servlets,
451 admin.register_servlets,
452 ]
453
454 def prepare(self, reactor, clock, hs):
455 self.user = self.register_user("user", "pass")
456 self.tok = self.login("user", "pass")
457
458 # Allow for uploading and downloading to/from the media repo
459 self.media_repo = hs.get_media_repository_resource()
460 self.download_resource = self.media_repo.children[b"download"]
461 self.upload_resource = self.media_repo.children[b"upload"]
462
463 def default_config(self):
464 config = default_config("test")
465
466 config.update(
467 {
468 "spam_checker": [
469 {
470 "module": TestSpamChecker.__module__ + ".TestSpamChecker",
471 "config": {},
472 }
473 ]
474 }
475 )
476
477 return config
478
479 def test_upload_innocent(self):
480 """Attempt to upload some innocent data that should be allowed."""
481
482 image_data = unhexlify(
483 b"89504e470d0a1a0a0000000d4948445200000001000000010806"
484 b"0000001f15c4890000000a49444154789c63000100000500010d"
485 b"0a2db40000000049454e44ae426082"
486 )
487
488 self.helper.upload_media(
489 self.upload_resource, image_data, tok=self.tok, expect_code=200
490 )
491
492 def test_upload_ban(self):
493 """Attempt to upload some data that includes bytes "evil", which should
494 get rejected by the spam checker.
495 """
496
497 data = b"Some evil data"
498
499 self.helper.upload_media(
500 self.upload_resource, data, tok=self.tok, expect_code=400
501 )
346346 self._tcp_callbacks[(host, port)] = callback
347347
348348 def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
349 """Fake L{IReactorTCP.connectTCP}.
350 """
349 """Fake L{IReactorTCP.connectTCP}."""
351350
352351 conn = super().connectTCP(
353352 host, port, factory, timeout=timeout, bindAddress=None
352352 tok = self.login(localpart, "password")
353353
354354 # Sync with the user's token to mark the user as active.
355 channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,)
355 channel = self.make_request(
356 "GET",
357 "/sync?timeout=0",
358 access_token=tok,
359 )
356360
357361 # Also retrieves the list of invites for this user. We don't care about that
358362 # one except if we're processing the last user, which should have received an
381381 self.do_check(events, edges, expected_state_ids)
382382
383383 def test_mainline_sort(self):
384 """Tests that the mainline ordering works correctly.
385 """
384 """Tests that the mainline ordering works correctly."""
386385
387386 events = [
388387 FakeEvent(
659658 # C -|-> B -> A
660659
661660 a = FakeEvent(
662 id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
661 id="A",
662 sender=ALICE,
663 type=EventTypes.Member,
664 state_key="",
665 content={},
663666 ).to_event([], [])
664667
665668 b = FakeEvent(
666 id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
669 id="B",
670 sender=ALICE,
671 type=EventTypes.Member,
672 state_key="",
673 content={},
667674 ).to_event([a.event_id], [])
668675
669676 c = FakeEvent(
670 id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
677 id="C",
678 sender=ALICE,
679 type=EventTypes.Member,
680 state_key="",
681 content={},
671682 ).to_event([b.event_id], [])
672683
673684 persisted_events = {a.event_id: a, b.event_id: b}
693704 # D -> C -|-> B -> A
694705
695706 a = FakeEvent(
696 id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
707 id="A",
708 sender=ALICE,
709 type=EventTypes.Member,
710 state_key="",
711 content={},
697712 ).to_event([], [])
698713
699714 b = FakeEvent(
700 id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
715 id="B",
716 sender=ALICE,
717 type=EventTypes.Member,
718 state_key="",
719 content={},
701720 ).to_event([a.event_id], [])
702721
703722 c = FakeEvent(
704 id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
723 id="C",
724 sender=ALICE,
725 type=EventTypes.Member,
726 state_key="",
727 content={},
705728 ).to_event([b.event_id], [])
706729
707730 d = FakeEvent(
708 id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
731 id="D",
732 sender=ALICE,
733 type=EventTypes.Member,
734 state_key="",
735 content={},
709736 ).to_event([c.event_id], [])
710737
711738 persisted_events = {a.event_id: a, b.event_id: b}
736763 # |
737764
738765 a = FakeEvent(
739 id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
766 id="A",
767 sender=ALICE,
768 type=EventTypes.Member,
769 state_key="",
770 content={},
740771 ).to_event([], [])
741772
742773 b = FakeEvent(
743 id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
774 id="B",
775 sender=ALICE,
776 type=EventTypes.Member,
777 state_key="",
778 content={},
744779 ).to_event([a.event_id], [])
745780
746781 c = FakeEvent(
747 id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
782 id="C",
783 sender=ALICE,
784 type=EventTypes.Member,
785 state_key="",
786 content={},
748787 ).to_event([b.event_id], [])
749788
750789 d = FakeEvent(
751 id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
790 id="D",
791 sender=ALICE,
792 type=EventTypes.Member,
793 state_key="",
794 content={},
752795 ).to_event([c.event_id], [])
753796
754797 e = FakeEvent(
755 id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
798 id="E",
799 sender=ALICE,
800 type=EventTypes.Member,
801 state_key="",
802 content={},
756803 ).to_event([c.event_id, b.event_id], [])
757804
758805 persisted_events = {a.event_id: a, b.event_id: b}
9595 # No ignored_users key.
9696 self.get_success(
9797 self.store.add_account_data_for_user(
98 self.user, AccountDataTypes.IGNORED_USER_LIST, {},
98 self.user,
99 AccountDataTypes.IGNORED_USER_LIST,
100 {},
99101 )
100102 )
101103
6666 async def update(progress, count):
6767 self.assertEqual(progress, {"my_key": 2})
6868 self.assertAlmostEqual(
69 count, target_background_update_duration_ms / duration_ms, places=0,
69 count,
70 target_background_update_duration_ms / duration_ms,
71 places=0,
7072 )
7173 await self.updates._end_background_update("test_update")
7274 return count
4242 self.room_id = info["room_id"]
4343
4444 def run_background_update(self):
45 """Re-run the background update to clean up the extremities.
46 """
45 """Re-run the background update to clean up the extremities."""
4746 # Make sure we don't clash with in progress updates.
4847 self.assertTrue(
4948 self.store.db_pool.updates._all_done, "Background updates are still ongoing"
4040 device_id = "MY_DEVICE"
4141
4242 # Insert a user IP
43 self.get_success(self.store.store_device(user_id, device_id, "display name",))
43 self.get_success(
44 self.store.store_device(
45 user_id,
46 device_id,
47 "display name",
48 )
49 )
4450 self.get_success(
4551 self.store.insert_client_ip(
4652 user_id, "access_token", "ip", "user_agent", device_id
213219 device_id = "MY_DEVICE"
214220
215221 # Insert a user IP
216 self.get_success(self.store.store_device(user_id, device_id, "display name",))
222 self.get_success(
223 self.store.store_device(
224 user_id,
225 device_id,
226 "display name",
227 )
228 )
217229 self.get_success(
218230 self.store.insert_client_ip(
219231 user_id, "access_token", "ip", "user_agent", device_id
302314 device_id = "MY_DEVICE"
303315
304316 # Insert a user IP
305 self.get_success(self.store.store_device(user_id, device_id, "display name",))
317 self.get_success(
318 self.store.store_device(
319 user_id,
320 device_id,
321 "display name",
322 )
323 )
306324 self.get_success(
307325 self.store.insert_client_ip(
308326 user_id, "access_token", "ip", "user_agent", device_id
8989 "content": {"tag": "power"},
9090 },
9191 ).build(
92 prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
92 prev_event_ids=[],
93 auth_event_ids=[create.event_id, bob_join.event_id],
9394 )
9495 )
9596
225226
226227 self.assertFalse(
227228 link_map.exists_path_from(
228 chain_map[create.event_id], chain_map[event.event_id],
229 chain_map[create.event_id],
230 chain_map[event.event_id],
229231 ),
230232 )
231233
286288 "content": {"tag": "power"},
287289 },
288290 ).build(
289 prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
291 prev_event_ids=[],
292 auth_event_ids=[create.event_id, bob_join.event_id],
290293 )
291294 )
292295
372375 )
373376
374377 def persist(
375 self, events: List[EventBase],
378 self,
379 events: List[EventBase],
376380 ):
377381 """Persist the given events and check that the links generated match
378382 those given.
393397 persist_events_store._persist_event_auth_chain_txn(txn, events)
394398
395399 self.get_success(
396 persist_events_store.db_pool.runInteraction("_persist", _persist,)
400 persist_events_store.db_pool.runInteraction(
401 "_persist",
402 _persist,
403 )
397404 )
398405
399406 def fetch_chains(
446453
447454 class LinkMapTestCase(unittest.TestCase):
448455 def test_simple(self):
449 """Basic tests for the LinkMap.
450 """
456 """Basic tests for the LinkMap."""
451457 link_map = _LinkMap()
452458
453459 link_map.add_link((1, 1), (2, 1), new=False)
489495 self.requester = create_requester(self.user_id)
490496
491497 def _generate_room(self) -> Tuple[str, List[Set[str]]]:
492 """Insert a room without a chain cover index.
493 """
498 """Insert a room without a chain cover index."""
494499 room_id = self.helper.create_room_as(self.user_id, tok=self.token)
495500
496501 # Mark the room as not having a chain cover index
214214 ],
215215 )
216216
217 self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
217 self.get_success(
218 self.store.db_pool.runInteraction(
219 "insert",
220 insert_event,
221 )
222 )
218223
219224 # Now actually test that various combinations give the right result:
220225
369374 )
370375
371376 self.hs.datastores.persist_events._persist_event_auth_chain_txn(
372 txn, [FakeEvent("b", room_id, auth_graph["b"])],
377 txn,
378 [FakeEvent("b", room_id, auth_graph["b"])],
373379 )
374380
375381 self.store.db_pool.simple_update_txn(
379385 updatevalues={"has_auth_chain_index": True},
380386 )
381387
382 self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
388 self.get_success(
389 self.store.db_pool.runInteraction(
390 "insert",
391 insert_event,
392 )
393 )
383394
384395 # Now actually test that various combinations give the right result:
385396
8383
8484 yield defer.ensureDeferred(
8585 self.store.add_push_actions_to_staging(
86 event.event_id, {user_id: action}, False,
86 event.event_id,
87 {user_id: action},
88 False,
8789 )
8890 )
8991 yield defer.ensureDeferred(
6767 self.assert_extremities([self.remote_event_1.event_id])
6868
6969 def persist_event(self, event, state=None):
70 """Persist the event, with optional state
71 """
70 """Persist the event, with optional state"""
7271 context = self.get_success(
7372 self.state.compute_event_context(event, old_state=state)
7473 )
7574 self.get_success(self.persistence.persist_event(event, context))
7675
7776 def assert_extremities(self, expected_extremities):
78 """Assert the current extremities for the room
79 """
77 """Assert the current extremities for the room"""
8078 extremities = self.get_success(
8179 self.store.get_prev_events_for_room(self.room_id)
8280 )
8585
8686 def _insert(txn):
8787 txn.execute(
88 "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
88 "INSERT INTO foobar VALUES (?, ?)",
89 (
90 stream_id,
91 instance_name,
92 ),
8993 )
9094 txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
9195 txn.execute(
137141 self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
138142
139143 def test_out_of_order_finish(self):
140 """Test that IDs persisted out of order are correctly handled
141 """
144 """Test that IDs persisted out of order are correctly handled"""
142145
143146 # Prefill table with 7 rows written by 'master'
144147 self._insert_rows("master", 7)
245248 self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
246249
247250 def test_get_next_txn(self):
248 """Test that the `get_next_txn` function works correctly.
249 """
251 """Test that the `get_next_txn` function works correctly."""
250252
251253 # Prefill table with 7 rows written by 'master'
252254 self._insert_rows("master", 7)
385387 self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
386388
387389 def test_writer_config_change(self):
388 """Test that changing the writer config correctly works.
389 """
390 """Test that changing the writer config correctly works."""
390391
391392 self._insert_row_with_id("first", 3)
392393 self._insert_row_with_id("second", 5)
433434 self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
434435
435436 def test_sequence_consistency(self):
436 """Test that we error out if the table and sequence diverge.
437 """
437 """Test that we error out if the table and sequence diverge."""
438438
439439 # Prefill with some rows
440440 self._insert_row_with_id("master", 3)
451451
452452
453453 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
454 """Tests MultiWriterIdGenerator that produces *negative* stream IDs.
455 """
454 """Tests MultiWriterIdGenerator that produces *negative* stream IDs."""
456455
457456 if not USE_POSTGRES_FOR_TESTS:
458457 skip = "Requires Postgres"
493492 return self.get_success(self.db_pool.runWithConnection(_create))
494493
495494 def _insert_row(self, instance_name: str, stream_id: int):
496 """Insert one row as the given instance with given stream_id.
497 """
495 """Insert one row as the given instance with given stream_id."""
498496
499497 def _insert(txn):
500498 txn.execute(
501 "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
499 "INSERT INTO foobar VALUES (?, ?)",
500 (
501 stream_id,
502 instance_name,
503 ),
502504 )
503505 txn.execute(
504506 """
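The `MultiWriterIdGenerator` tests above exercise IDs persisted out of order across several writers. The key invariant the assertions describe is that a reader's "current token" only advances to the lowest position every writer has fully persisted. This is a toy model of that invariant (hypothetical class and method names, not Synapse's implementation):

```python
class ToyMultiWriterPositions:
    """Toy model: per-writer persisted stream positions."""

    def __init__(self):
        self._positions = {}  # writer name -> highest persisted stream id

    def advance(self, writer: str, stream_id: int) -> None:
        self._positions[writer] = max(self._positions.get(writer, 0), stream_id)

    def get_positions(self) -> dict:
        return dict(self._positions)

    def get_current_token(self) -> int:
        # Safe point for readers: nothing at or below this can still be in flight.
        return min(self._positions.values(), default=0)

gen = ToyMultiWriterPositions()
gen.advance("first", 3)
gen.advance("second", 5)
assert gen.get_positions() == {"first": 3, "second": 5}
assert gen.get_current_token() == 3  # held back by the slowest writer
```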
197197 # value, although it gets stored on the config object as mau_limits.
198198 @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
199199 def test_reap_monthly_active_users_reserved_users(self):
200 """ Tests that reaping correctly handles reaping where reserved users are
200 """Tests that reaping correctly handles reaping where reserved users are
201201 present"""
202202 threepids = self.hs.config.mau_limits_reserved_threepids
203203 initial_users = len(threepids)
298298 )
299299
300300 def test_redact_censor(self):
301 """Test that a redacted event gets censored in the DB after a month
302 """
301 """Test that a redacted event gets censored in the DB after a month"""
303302
304303 self.get_success(
305304 self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
369368 self.assert_dict({"content": {}}, json.loads(event_json))
370369
371370 def test_redact_redaction(self):
372 """Tests that we can redact a redaction and can fetch it again.
373 """
371 """Tests that we can redact a redaction and can fetch it again."""
374372
375373 self.get_success(
376374 self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
403401 )
404402
405403 def test_store_redacted_redaction(self):
406 """Tests that we can store a redacted redaction.
407 """
404 """Tests that we can store a redacted redaction."""
408405
409406 self.get_success(
410407 self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
5151 "creation_ts": 1000,
5252 "user_type": None,
5353 "deactivated": 0,
54 "shadow_banned": 0,
5455 },
5556 (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
5657 )
144145 try:
145146 yield defer.ensureDeferred(
146147 self.store.validate_threepid_session(
147 "fake_sid", "fake_client_secret", "fake_token", 0,
148 "fake_sid",
149 "fake_client_secret",
150 "fake_token",
151 0,
148152 )
149153 )
150154 except ThreepidValidationError as e:
157161 try:
158162 yield defer.ensureDeferred(
159163 self.store.validate_threepid_session(
160 "fake_sid", "fake_client_secret", "fake_token", 0,
164 "fake_sid",
165 "fake_client_secret",
166 "fake_token",
167 0,
161168 )
162169 )
163170 except ThreepidValidationError as e:
8484
8585 # king should be able to send state
8686 event_auth.check(
87 RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
87 RoomVersions.V1,
88 _random_state_event(king),
89 auth_events,
90 do_sig_check=False,
8891 )
8992
9093 def test_alias_event(self):
98101
99102 # creator should be able to send aliases
100103 event_auth.check(
101 RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
104 RoomVersions.V1,
105 _alias_event(creator),
106 auth_events,
107 do_sig_check=False,
102108 )
103109
104110 # Reject an event with no state key.
121127
122128 # Note that the member does *not* need to be in the room.
123129 event_auth.check(
124 RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
130 RoomVersions.V1,
131 _alias_event(other),
132 auth_events,
133 do_sig_check=False,
125134 )
126135
127136 def test_msc2432_alias_event(self):
135144
136145 # creator should be able to send aliases
137146 event_auth.check(
138 RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
147 RoomVersions.V6,
148 _alias_event(creator),
149 auth_events,
150 do_sig_check=False,
139151 )
140152
141153 # No particular checks are done on the state key.
155167 # Per standard auth rules, the member must be in the room.
156168 with self.assertRaises(AuthError):
157169 event_auth.check(
158 RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
170 RoomVersions.V6,
171 _alias_event(other),
172 auth_events,
173 do_sig_check=False,
159174 )
160175
161176 def test_msc2209(self):
241241 )
242242
243243 channel = self.make_request(
244 "POST", "/register", request_data, access_token=token,
244 "POST",
245 "/register",
246 request_data,
247 access_token=token,
245248 )
246249
247250 if channel.code != 200:
2020
2121
2222 def get_sample_labels_value(sample):
23 """ Extract the labels and values of a sample.
23 """Extract the labels and values of a sample.
2424
2525 prometheus_client 0.5 changed the sample type to a named tuple with more
2626 members than the plain tuple had in 0.4 and earlier. This function can
1414
1515 from synapse.rest.media.v1.preview_url_resource import (
1616 decode_and_calc_og,
17 get_html_media_encoding,
1718 summarize_paragraphs,
1819 )
1920
2526 lxml = None
2627
2728
28 class PreviewTestCase(unittest.TestCase):
29 class SummarizeTestCase(unittest.TestCase):
2930 if not lxml:
3031 skip = "url preview feature requires lxml"
3132
143144 )
144145
145146
146 class PreviewUrlTestCase(unittest.TestCase):
147 class CalcOgTestCase(unittest.TestCase):
147148 if not lxml:
148149 skip = "url preview feature requires lxml"
149150
150151 def test_simple(self):
151 html = """
152 html = b"""
152153 <html>
153154 <head><title>Foo</title></head>
154155 <body>
162163 self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
163164
164165 def test_comment(self):
165 html = """
166 html = b"""
166167 <html>
167168 <head><title>Foo</title></head>
168169 <body>
177178 self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
178179
179180 def test_comment2(self):
180 html = """
181 html = b"""
181182 <html>
182183 <head><title>Foo</title></head>
183184 <body>
201202 )
202203
203204 def test_script(self):
204 html = """
205 html = b"""
205206 <html>
206207 <head><title>Foo</title></head>
207208 <body>
216217 self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
217218
218219 def test_missing_title(self):
219 html = """
220 html = b"""
220221 <html>
221222 <body>
222223 Some text.
229230 self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
230231
231232 def test_h1_as_title(self):
232 html = """
233 html = b"""
233234 <html>
234235 <meta property="og:description" content="Some text."/>
235236 <body>
243244 self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
244245
245246 def test_missing_title_and_broken_h1(self):
246 html = """
247 html = b"""
247248 <html>
248249 <body>
249250 <h1><a href="foo"/></h1>
257258 self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
258259
259260 def test_empty(self):
260 html = ""
261 """Test a body with no data in it."""
262 html = b""
263 og = decode_and_calc_og(html, "http://example.com/test.html")
264 self.assertEqual(og, {})
265
266 def test_no_tree(self):
267 """A valid body with no tree in it."""
268 html = b"\x00"
261269 og = decode_and_calc_og(html, "http://example.com/test.html")
262270 self.assertEqual(og, {})
263271
264272 def test_invalid_encoding(self):
265273 """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
266 html = """
274 html = b"""
267275 <html>
268276 <head><title>Foo</title></head>
269277 <body>
289297 """
290298 og = decode_and_calc_og(html, "http://example.com/test.html")
291299 self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
300
301
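The `CalcOgTestCase` expectations above pin down the extraction rules: explicit `og:` meta tags win, otherwise `<title>` (or a lone `<h1>`) becomes `og:title`, and a missing title yields `og:title: None`. That fallback logic can be approximated with the standard-library parser; this is a simplified sketch, not Synapse's lxml-based `decode_and_calc_og`:

```python
from html.parser import HTMLParser

class _OgParser(HTMLParser):
    """Collect og: meta tags and the <title> text from an HTML document."""

    def __init__(self):
        super().__init__()
        self.og = {}
        self.title = None
        self._in_title = False

    def handle_starttag(self, tag, attrs):
        attrs = dict(attrs)
        prop = attrs.get("property") or ""
        if tag == "meta" and prop.startswith("og:"):
            self.og[prop] = attrs.get("content")
        elif tag == "title":
            self._in_title = True

    def handle_endtag(self, tag):
        if tag == "title":
            self._in_title = False

    def handle_data(self, data):
        if self._in_title:
            self.title = (self.title or "") + data

def calc_og(html: str) -> dict:
    parser = _OgParser()
    parser.feed(html)
    og = dict(parser.og)
    # Fall back to <title>; stays None when the document has no title,
    # matching the test_missing_title expectation above.
    og.setdefault("og:title", parser.title)
    return og

og = calc_og('<html><head><title>Foo</title>'
             '<meta property="og:description" content="Some text."/></head></html>')
assert og == {"og:title": "Foo", "og:description": "Some text."}
```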
302 class MediaEncodingTestCase(unittest.TestCase):
303 def test_meta_charset(self):
304 """A character encoding is found via the meta tag."""
305 encoding = get_html_media_encoding(
306 b"""
307 <html>
308 <head><meta charset="ascii">
309 </head>
310 </html>
311 """,
312 "text/html",
313 )
314 self.assertEqual(encoding, "ascii")
315
316 # A less well-formed version.
317 encoding = get_html_media_encoding(
318 b"""
319 <html>
320 <head>< meta charset = ascii>
321 </head>
322 </html>
323 """,
324 "text/html",
325 )
326 self.assertEqual(encoding, "ascii")
327
328 def test_xml_encoding(self):
329 """A character encoding is found via the meta tag."""
330 encoding = get_html_media_encoding(
331 b"""
332 <?xml version="1.0" encoding="ascii"?>
333 <html>
334 </html>
335 """,
336 "text/html",
337 )
338 self.assertEqual(encoding, "ascii")
339
340 def test_meta_xml_encoding(self):
341 """Meta tags take precedence over XML encoding."""
342 encoding = get_html_media_encoding(
343 b"""
344 <?xml version="1.0" encoding="ascii"?>
345 <html>
346 <head><meta charset="UTF-16">
347 </head>
348 </html>
349 """,
350 "text/html",
351 )
352 self.assertEqual(encoding, "UTF-16")
353
354 def test_content_type(self):
355 """A character encoding is found via the Content-Type header."""
356 # Test a few variations of the header.
357 headers = (
358 'text/html; charset="ascii";',
359 "text/html;charset=ascii;",
360 'text/html; charset="ascii"',
361 "text/html; charset=ascii",
362 'text/html; charset="ascii;',
363 'text/html; charset=ascii";',
364 )
365 for header in headers:
366 encoding = get_html_media_encoding(b"", header)
367 self.assertEqual(encoding, "ascii")
368
369 def test_fallback(self):
370 """A character encoding cannot be found in the body or header."""
371 encoding = get_html_media_encoding(b"", "text/html")
372 self.assertEqual(encoding, "utf-8")
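The new `MediaEncodingTestCase` above pins down a precedence order for sniffing a page's character encoding: a `<meta charset>` tag beats an `<?xml ... encoding=...?>` declaration, which beats the `Content-Type` charset, with `utf-8` as the last resort. A regex-based sketch of that order (an approximation of what `get_html_media_encoding` tests for; the real implementation may differ):

```python
import re

_META_RE = re.compile(rb'<\s*meta[^>]*charset\s*=\s*["\']?([a-zA-Z0-9_-]+)', re.I)
_XML_RE = re.compile(rb'<\?xml[^>]*encoding\s*=\s*["\']?([a-zA-Z0-9_-]+)', re.I)
_CT_RE = re.compile(r'charset\s*=\s*["\']?([a-zA-Z0-9_-]+)', re.I)

def sniff_html_encoding(body: bytes, content_type: str) -> str:
    # 1. A <meta charset=...> in the document wins.
    match = _META_RE.search(body)
    if match:
        return match.group(1).decode("ascii")
    # 2. Then an <?xml ... encoding=...?> declaration.
    match = _XML_RE.search(body)
    if match:
        return match.group(1).decode("ascii")
    # 3. Then the charset parameter of the Content-Type header.
    match = _CT_RE.search(content_type)
    if match:
        return match.group(1)
    # 4. Otherwise assume UTF-8.
    return "utf-8"

assert sniff_html_encoding(b'<head><meta charset="ascii"></head>', "text/html") == "ascii"
assert sniff_html_encoding(b'<?xml version="1.0" encoding="ascii"?>', "text/html") == "ascii"
assert sniff_html_encoding(b"", 'text/html; charset="ascii"') == "ascii"
assert sniff_html_encoding(b"", "text/html") == "utf-8"
```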
165165
166166 res = JsonResource(self.homeserver)
167167 res.register_paths(
168 "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
168 "GET",
169 [re.compile("^/_matrix/foo$")],
170 _callback,
171 "test_servlet",
169172 )
170173
171174 # The path was registered as GET, but this is a HEAD request.
254254 # We need a valid token ID to satisfy foreign key constraints.
255255 token_id = self.get_success(
256256 self.hs.get_datastore().add_access_token_to_user(
257 self.helper.auth_user_id, "some_fake_token", None, None,
257 self.helper.auth_user_id,
258 "some_fake_token",
259 None,
260 None,
258261 )
259262 )
260263
0 # -*- coding: utf-8 -*-
1 # Copyright 2021 The Matrix.org Foundation C.I.C.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 from unittest.mock import Mock
15
16 from twisted.internet import defer
17 from twisted.internet.defer import Deferred
18
19 from synapse.util.caches.cached_call import CachedCall, RetryOnExceptionCachedCall
20
21 from tests.test_utils import get_awaitable_result
22 from tests.unittest import TestCase
23
24
25 class CachedCallTestCase(TestCase):
26 def test_get(self):
27 """
28 Happy-path test case: makes a couple of calls and makes sure they behave
29 correctly
30 """
31 d = Deferred()
32
33 async def f():
34 return await d
35
36 slow_call = Mock(side_effect=f)
37
38 cached_call = CachedCall(slow_call)
39
40 # the mock should not yet have been called
41 slow_call.assert_not_called()
42
43 # now fire off a couple of calls
44 completed_results = []
45
46 async def r():
47 res = await cached_call.get()
48 completed_results.append(res)
49
50 r1 = defer.ensureDeferred(r())
51 r2 = defer.ensureDeferred(r())
52
53 # neither result should be complete yet
54 self.assertNoResult(r1)
55 self.assertNoResult(r2)
56
57 # and the mock should have been called *once*, with no params
58 slow_call.assert_called_once_with()
59
60 # allow the deferred to complete, which should complete both the pending results
61 d.callback(123)
62 self.assertEqual(completed_results, [123, 123])
63 self.successResultOf(r1)
64 self.successResultOf(r2)
65
66 # another call to the getter should complete immediately
67 slow_call.reset_mock()
68 r3 = get_awaitable_result(cached_call.get())
69 self.assertEqual(r3, 123)
70 slow_call.assert_not_called()
71
72 def test_fast_call(self):
73 """
74 Test the behaviour when the underlying function completes immediately
75 """
76
77 async def f():
78 return 12
79
80 fast_call = Mock(side_effect=f)
81 cached_call = CachedCall(fast_call)
82
83 # the mock should not yet have been called
84 fast_call.assert_not_called()
85
86 # run the call a couple of times, which should complete immediately
87 self.assertEqual(get_awaitable_result(cached_call.get()), 12)
88 self.assertEqual(get_awaitable_result(cached_call.get()), 12)
89
90 # the mock should have been called once
91 fast_call.assert_called_once_with()
92
93
94 class RetryOnExceptionCachedCallTestCase(TestCase):
95 def test_get(self):
96 # set up the RetryOnExceptionCachedCall around a function which will fail
97 # (after a while)
98 d = Deferred()
99
100 async def f1():
101 await d
102 raise ValueError("moo")
103
104 slow_call = Mock(side_effect=f1)
105 cached_call = RetryOnExceptionCachedCall(slow_call)
106
107 # the mock should not yet have been called
108 slow_call.assert_not_called()
109
110 # now fire off a couple of calls
111 completed_results = []
112
113 async def r():
114 try:
115 await cached_call.get()
116 except Exception as e1:
117 completed_results.append(e1)
118
119 r1 = defer.ensureDeferred(r())
120 r2 = defer.ensureDeferred(r())
121
122 # neither result should be complete yet
123 self.assertNoResult(r1)
124 self.assertNoResult(r2)
125
126 # and the mock should have been called *once*, with no params
127 slow_call.assert_called_once_with()
128
129 # complete the deferred, which should make the pending calls fail
130 d.callback(0)
131 self.assertEqual(len(completed_results), 2)
132 for e in completed_results:
133 self.assertIsInstance(e, ValueError)
134 self.assertEqual(e.args, ("moo",))
135
136 # reset the mock to return a successful result, and make another pair of calls
137 # to the getter
138 d = Deferred()
139
140 async def f2():
141 return await d
142
143 slow_call.reset_mock()
144 slow_call.side_effect = f2
145 r3 = defer.ensureDeferred(cached_call.get())
146 r4 = defer.ensureDeferred(cached_call.get())
147
148 self.assertNoResult(r3)
149 self.assertNoResult(r4)
150 slow_call.assert_called_once_with()
151
152 # let that call complete, and check the results
153 d.callback(123)
154 self.assertEqual(self.successResultOf(r3), 123)
155 self.assertEqual(self.successResultOf(r4), 123)
156
157 # and now more calls to the getter should complete immediately
158 slow_call.reset_mock()
159 self.assertEqual(get_awaitable_result(cached_call.get()), 123)
160 slow_call.assert_not_called()
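The new `tests/util/test_cached_call.py` above specifies `CachedCall`'s contract: the wrapped coroutine runs at most once, concurrent getters share the single pending result, and later getters return it immediately. An asyncio-flavoured sketch of the same idea (Synapse's own class is Twisted-based and handles exceptions and logcontexts beyond this):

```python
import asyncio

class SimpleCachedCall:
    """Run an async zero-argument callable at most once; share the result."""

    def __init__(self, f):
        self._f = f
        self._task = None

    async def get(self):
        if self._task is None:
            # First caller kicks the work off; everyone else awaits the same task.
            self._task = asyncio.ensure_future(self._f())
        return await self._task

calls = []

async def slow():
    calls.append(1)
    await asyncio.sleep(0)
    return 123

async def main():
    cached = SimpleCachedCall(slow)
    # Two concurrent getters share one invocation of slow().
    r1, r2 = await asyncio.gather(cached.get(), cached.get())
    r3 = await cached.get()  # already complete: returns the cached result
    return r1, r2, r3

results = asyncio.run(main())
assert results == (123, 123, 123)
assert len(calls) == 1  # the underlying coroutine ran exactly once
```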
231231
232232 def test_eviction_iterable(self):
233233 cache = DeferredCache(
234 "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
234 "test",
235 max_entries=3,
236 apply_cache_factor_from_config=False,
237 iterable=True,
235238 )
236239
237240 cache.prefill(1, ["one", "two"])
142142 obj.mock.assert_not_called()
143143
144144 def test_cache_with_sync_exception(self):
145 """If the wrapped function throws synchronously, things should continue to work
146 """
145 """If the wrapped function throws synchronously, things should continue to work"""
147146
148147 class Cls:
149148 @cached()
164163 self.failureResultOf(d, SynapseError)
165164
166165 def test_cache_with_async_exception(self):
167 """The wrapped function returns a failure
168 """
166 """The wrapped function returns a failure"""
169167
170168 class Cls:
171169 result = None
281279 try:
282280 d = obj.fn(1)
283281 self.assertEqual(
284 current_context(), SENTINEL_CONTEXT,
282 current_context(),
283 SENTINEL_CONTEXT,
285284 )
286285 yield d
287286 self.fail("No exception thrown")
373372 obj.mock.assert_not_called()
374373
375374 def test_cache_iterable_with_sync_exception(self):
376 """If the wrapped function throws synchronously, things should continue to work
377 """
375 """If the wrapped function throws synchronously, things should continue to work"""
378376
379377 class Cls:
380378 @descriptors.cached(iterable=True)
2323 parts = chunk_seq("123", 8)
2424
2525 self.assertEqual(
26 list(parts), ["123"],
26 list(parts),
27 ["123"],
2728 )
2829
2930 def test_long_seq(self):
3031 parts = chunk_seq("abcdefghijklmnop", 8)
3132
3233 self.assertEqual(
33 list(parts), ["abcdefgh", "ijklmnop"],
34 list(parts),
35 ["abcdefgh", "ijklmnop"],
3436 )
3537
3638 def test_uneven_parts(self):
3739 parts = chunk_seq("abcdefghijklmnop", 5)
3840
3941 self.assertEqual(
40 list(parts), ["abcde", "fghij", "klmno", "p"],
42 list(parts),
43 ["abcde", "fghij", "klmno", "p"],
4144 )
4245
4346 def test_empty_input(self):
4447 parts = chunk_seq([], 5)
4548
4649 self.assertEqual(
47 list(parts), [],
50 list(parts),
51 [],
4852 )
4953
5054
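The `chunk_seq` assertions above fully specify its behaviour: split a sequence into consecutive pieces of at most `maxlen` items, with a shorter final piece and no pieces at all for empty input. A straightforward reimplementation for illustration (Synapse's own helper may differ in detail):

```python
def chunk_seq(seq, maxlen):
    """Yield successive slices of ``seq`` holding at most ``maxlen`` items each."""
    for start in range(0, len(seq), maxlen):
        yield seq[start : start + maxlen]

assert list(chunk_seq("abcdefghijklmnop", 8)) == ["abcdefgh", "ijklmnop"]
assert list(chunk_seq("abcdefghijklmnop", 5)) == ["abcde", "fghij", "klmno", "p"]
assert list(chunk_seq("123", 8)) == ["123"]
assert list(chunk_seq([], 5)) == []
```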
6969 self.assertTrue("user@foo.com" not in cache._entity_to_key)
7070
7171 self.assertEqual(
72 cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"],
72 cache.get_all_entities_changed(2),
73 ["bar@baz.net", "user@elsewhere.org"],
7374 )
7475 self.assertIsNone(cache.get_all_entities_changed(1))
7576
7980 {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
8081 )
8182 self.assertEqual(
82 cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"],
83 cache.get_all_entities_changed(2),
84 ["user@elsewhere.org", "bar@baz.net"],
8385 )
8486 self.assertIsNone(cache.get_all_entities_changed(1))
8587
221223 # Query a subset of the entries mid-way through the stream. We should
222224 # only get back the subset.
223225 self.assertEqual(
224 cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
226 cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
227 {"bar@baz.net"},
225228 )
226229
227230 def test_max_pos(self):
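The stream-change-cache assertions above describe a cache mapping entities to the stream position at which they last changed, answering "which of these entities changed after position X?" and refusing to answer below its earliest known position. A toy model of those semantics (hypothetical names, not Synapse's `StreamChangeCache`):

```python
class ToyStreamChangeCache:
    """Toy model: entity -> stream position of its most recent change."""

    def __init__(self, earliest_known_stream_pos):
        self._earliest = earliest_known_stream_pos
        self._entity_to_pos = {}

    def entity_has_changed(self, entity, stream_pos):
        self._entity_to_pos[entity] = stream_pos

    def get_entities_changed(self, entities, stream_pos):
        if stream_pos < self._earliest:
            # The cache can't answer below its horizon: assume everything
            # in the requested subset may have changed.
            return set(entities)
        return {
            e for e in entities
            if self._entity_to_pos.get(e, self._earliest) > stream_pos
        }

cache = ToyStreamChangeCache(earliest_known_stream_pos=1)
cache.entity_has_changed("user@foo.com", 2)
cache.entity_has_changed("bar@baz.net", 3)
# Querying a subset mid-way through the stream returns only that subset's changes.
assert cache.get_entities_changed(["bar@baz.net"], stream_pos=2) == {"bar@baz.net"}
assert cache.get_entities_changed(["user@foo.com"], stream_pos=2) == set()
```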
262262 db_conn.close()
263263
264264 hs = homeserver_to_use(
265 name, config=config, version_string="Synapse/tests", reactor=reactor,
265 name,
266 config=config,
267 version_string="Synapse/tests",
268 reactor=reactor,
266269 )
267270
268271 # Install @cache_in_self attributes
364367 def trigger(
365368 self, http_method, path, content, mock_request, federation_auth_origin=None
366369 ):
367 """ Fire an HTTP event.
370 """Fire an HTTP event.
368371
369372 Args:
370373 http_method : The HTTP method
527530
528531
529532 async def create_room(hs, room_id: str, creator_id: str):
530 """Creates and persists a creation event for the given room
531 """
533 """Creates and persists a creation event for the given room"""
532534
533535 persistence_store = hs.get_storage().persistence
534536 store = hs.get_datastore()