Update upstream source from tag 'upstream/1.30.0'
Update to upstream version '1.30.0'
with Debian dir df09c74e0d9e271673eecd80d71bdccf3cb612e1
Andrej Shadura
0 | # Black reformatting (#5482). | |
1 | 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 | |
2 | ||
3 | # Target Python 3.5 with black (#8664). | |
4 | aff1eb7c671b0a3813407321d2702ec46c71fa56 | |
5 | ||
6 | # Update black to 20.8b1 (#9381). | |
7 | 0a00b7ff14890987f09112a2ae696c61001e6cf1 |
0 | Synapse 1.30.0 (2021-03-22) | |
1 | =========================== | |
2 | ||
3 | Note that this release deprecates the ability for appservices to | |
4 | call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice | |
5 | developers should use a `type` value of `m.login.application_service` as | |
6 | per [the spec](https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions). | |
7 | In future releases, calling this endpoint with an access token - but without an `m.login.application_service` | |
8 | type - will fail. | |
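For illustration, here is a minimal sketch of the new-style registration call. The homeserver URL, `as_token` and username below are placeholders; any HTTP client works in place of `requests`, and the token may equivalently be passed as an `access_token` query parameter:

```python
import requests

# Hypothetical values: substitute your homeserver's URL and the as_token
# from the appservice registration file.
HOMESERVER = "https://homeserver.example"
AS_TOKEN = "secret-as-token"

resp = requests.post(
    f"{HOMESERVER}/_matrix/client/r0/register",
    headers={"Authorization": f"Bearer {AS_TOKEN}"},
    json={
        "type": "m.login.application_service",  # now required
        "username": "_myas_alice",
    },
)
resp.raise_for_status()  # raises if the homeserver rejects the registration
```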
9 | ||
10 | ||
11 | No significant changes. | |
12 | ||
13 | ||
14 | Synapse 1.30.0rc1 (2021-03-16) | |
15 | ============================== | |
16 | ||
17 | Features | |
18 | -------- | |
19 | ||
20 | - Add Prometheus metrics for the number of users successfully registering and logging in. ([\#9510](https://github.com/matrix-org/synapse/issues/9510), [\#9511](https://github.com/matrix-org/synapse/issues/9511), [\#9573](https://github.com/matrix-org/synapse/issues/9573)) | |
21 | - Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` Prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent to and received from a set of remote servers. ([\#9540](https://github.com/matrix-org/synapse/issues/9540)) | |
22 | - Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets. ([\#9549](https://github.com/matrix-org/synapse/issues/9549)) | |
23 | - Optimise handling of incomplete room history for incoming federation. ([\#9601](https://github.com/matrix-org/synapse/issues/9601)) | |
24 | - Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9617](https://github.com/matrix-org/synapse/issues/9617)) | |
25 | - Tell spam checker modules about the SSO IdP a user registered through, if one was used. ([\#9626](https://github.com/matrix-org/synapse/issues/9626)) | |
26 | ||
27 | ||
28 | Bugfixes | |
29 | -------- | |
30 | ||
31 | - Fix long-standing bug when generating thumbnails for some images with transparency: `TypeError: cannot unpack non-iterable int object`. ([\#9473](https://github.com/matrix-org/synapse/issues/9473)) | |
32 | - Purge chain cover indexes for events that were purged prior to Synapse v1.29.0. ([\#9542](https://github.com/matrix-org/synapse/issues/9542), [\#9583](https://github.com/matrix-org/synapse/issues/9583)) | |
33 | - Fix bug where federation requests were not correctly retried on 5xx responses. ([\#9567](https://github.com/matrix-org/synapse/issues/9567)) | |
34 | - Fix re-activating an account via the admin API when local passwords are disabled. ([\#9587](https://github.com/matrix-org/synapse/issues/9587)) | |
35 | - Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, slowing recovery from outages. ([\#9597](https://github.com/matrix-org/synapse/issues/9597)) | |
36 | - Fix a bug introduced in v1.28.0 where the OpenID Connect callback endpoint could error with a `MacaroonInitException`. ([\#9620](https://github.com/matrix-org/synapse/issues/9620)) | |
37 | - Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. ([\#9623](https://github.com/matrix-org/synapse/issues/9623)) | |
38 | ||
39 | ||
40 | Updates to the Docker image | |
41 | --------------------------- | |
42 | ||
43 | - Make use of an improved malloc implementation (`jemalloc`) in the Docker image. ([\#8553](https://github.com/matrix-org/synapse/issues/8553)) | |
44 | ||
45 | ||
46 | Improved Documentation | |
47 | ---------------------- | |
48 | ||
49 | - Add relayd entry to reverse proxy example configurations. ([\#9508](https://github.com/matrix-org/synapse/issues/9508)) | |
50 | - Improve the SAML2 upgrade notes for 1.27.0. ([\#9550](https://github.com/matrix-org/synapse/issues/9550)) | |
51 | - Link to the "List user's media" admin API from the media admin API docs. ([\#9571](https://github.com/matrix-org/synapse/issues/9571)) | |
52 | - Clarify the spam checker modules documentation example to mention that `parse_config` is a required method. ([\#9580](https://github.com/matrix-org/synapse/issues/9580)) | |
53 | - Clarify the sample configuration for `stats` settings. ([\#9604](https://github.com/matrix-org/synapse/issues/9604)) | |
54 | ||
55 | ||
56 | Deprecations and Removals | |
57 | ------------------------- | |
58 | ||
59 | - The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` Prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`. ([\#9540](https://github.com/matrix-org/synapse/issues/9540)) | |
60 | - Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release. ([\#9559](https://github.com/matrix-org/synapse/issues/9559)) | |
61 | ||
62 | ||
63 | Internal Changes | |
64 | ---------------- | |
65 | ||
66 | - Add tests to ResponseCache. ([\#9458](https://github.com/matrix-org/synapse/issues/9458)) | |
67 | - Add type hints to purge room and server notice admin API. ([\#9520](https://github.com/matrix-org/synapse/issues/9520)) | |
68 | - Add extra logging to ObservableDeferred when callbacks throw exceptions. ([\#9523](https://github.com/matrix-org/synapse/issues/9523)) | |
69 | - Fix incorrect type hints. ([\#9528](https://github.com/matrix-org/synapse/issues/9528), [\#9543](https://github.com/matrix-org/synapse/issues/9543), [\#9591](https://github.com/matrix-org/synapse/issues/9591), [\#9608](https://github.com/matrix-org/synapse/issues/9608), [\#9618](https://github.com/matrix-org/synapse/issues/9618)) | |
70 | - Add an additional test for purging a room. ([\#9541](https://github.com/matrix-org/synapse/issues/9541)) | |
71 | - Add a `.git-blame-ignore-revs` file with the hashes of auto-formatting. ([\#9560](https://github.com/matrix-org/synapse/issues/9560)) | |
72 | - Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. ([\#9561](https://github.com/matrix-org/synapse/issues/9561)) | |
73 | - Fix spurious errors reported by the `config-lint.sh` script. ([\#9562](https://github.com/matrix-org/synapse/issues/9562)) | |
74 | - Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper. ([\#9563](https://github.com/matrix-org/synapse/issues/9563)) | |
75 | - Do not have mypy ignore type hints from unpaddedbase64. ([\#9568](https://github.com/matrix-org/synapse/issues/9568)) | |
76 | - Improve efficiency of calculating the auth chain in large rooms. ([\#9576](https://github.com/matrix-org/synapse/issues/9576)) | |
77 | - Convert `synapse.types.Requester` to an `attrs` class. ([\#9586](https://github.com/matrix-org/synapse/issues/9586)) | |
78 | - Add logging for redis connection setup. ([\#9590](https://github.com/matrix-org/synapse/issues/9590)) | |
79 | - Improve logging when processing incoming transactions. ([\#9596](https://github.com/matrix-org/synapse/issues/9596)) | |
80 | - Remove unused `stats.retention` setting, and emit a warning if stats are disabled. ([\#9604](https://github.com/matrix-org/synapse/issues/9604)) | |
81 | - Prevent attempting to bundle aggregations for state events in /context APIs. ([\#9619](https://github.com/matrix-org/synapse/issues/9619)) | |
82 | ||
83 | ||
0 | 84 | Synapse 1.29.0 (2021-03-08) |
1 | 85 | =========================== |
2 | 86 |
19 | 19 | recursive-include scripts-dev * |
20 | 20 | recursive-include synapse *.pyi |
21 | 21 | recursive-include tests *.py |
22 | include tests/http/ca.crt | |
23 | include tests/http/ca.key | |
24 | include tests/http/server.key | |
22 | recursive-include tests *.pem | |
23 | recursive-include tests *.p8 | |
24 | recursive-include tests *.crt | |
25 | recursive-include tests *.key | |
25 | 26 | |
26 | 27 | recursive-include synapse/res * |
27 | 28 | recursive-include synapse/static *.css |
182 | 182 | It is recommended to put a reverse proxy such as |
183 | 183 | `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_, |
184 | 184 | `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_, |
185 | `Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or | |
186 | `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of | |
185 | `Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_, | |
186 | `HAProxy <https://www.haproxy.org/>`_ or | |
187 | `relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of | |
187 | 188 | doing so is that it means that you can expose the default https port (443) to |
188 | 189 | Matrix clients without needing to run Synapse with root privileges. |
189 | 190 |
122 | 122 | * If your server is configured for single sign-on via a SAML2 identity provider, you will |
123 | 123 | need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted |
124 | 124 | "ACS location" (also known as "allowed callback URLs") at the identity provider. |
125 | ||
126 | The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to | |
127 | ``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity | |
128 | provider uses this property to validate or otherwise identify Synapse, its configuration | |
129 | will need to be updated to use the new URL. Alternatively you could create a new, separate | |
130 | "EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in | |
131 | the existing "EntityDescriptor" as they were. | |
125 | 132 | |
126 | 133 | Changes to HTML templates |
127 | 134 | ------------------------- |
68 | 68 | libpq5 \ |
69 | 69 | libwebp6 \ |
70 | 70 | xmlsec1 \ |
71 | libjemalloc2 \ | |
71 | 72 | && rm -rf /var/lib/apt/lists/* |
72 | 73 | |
73 | 74 | COPY --from=builder /install /usr/local |
203 | 203 | timeout: 10s |
204 | 204 | retries: 3 |
205 | 205 | ``` |
206 | ||
207 | ## Using jemalloc | |
208 | ||
209 | Jemalloc is embedded in the image and will be used instead of the default allocator. | |
210 | You can read more about jemalloc in the Synapse [README](../README.md).
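As a quick sanity check, one can confirm from inside a running container that jemalloc really is mapped into the Synapse process. A minimal sketch, assuming a Linux `/proc` filesystem and that Synapse runs as PID 1 (as it does under this image's entrypoint):

```python
# Run via e.g. `docker exec <container> python -c '...'`.
# Checks whether libjemalloc appears in the memory map of PID 1.
with open("/proc/1/maps") as maps:
    print("jemalloc loaded:", any("libjemalloc" in line for line in maps))
```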
2 | 2 | import codecs |
3 | 3 | import glob |
4 | 4 | import os |
5 | import platform | |
5 | 6 | import subprocess |
6 | 7 | import sys |
7 | 8 | |
212 | 213 | if "-m" not in args: |
213 | 214 | args = ["-m", synapse_worker] + args |
214 | 215 | |
216 | jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),) | |
217 | ||
218 | if os.path.isfile(jemallocpath): | |
219 | environ["LD_PRELOAD"] = jemallocpath | |
220 | else: | |
221 | log("Could not find %s, will not use" % (jemallocpath,)) | |
222 | ||
215 | 223 | # if there are no config files passed to synapse, try adding the default file |
216 | 224 | if not any(p.startswith("--config-path") or p.startswith("-c") for p in args): |
217 | 225 | config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data") |
247 | 255 | args = ["python"] + args |
248 | 256 | if ownership is not None: |
249 | 257 | args = ["gosu", ownership] + args |
250 | os.execv("/usr/sbin/gosu", args) | |
251 | else: | |
252 | os.execv("/usr/local/bin/python", args) | |
258 | os.execve("/usr/sbin/gosu", args, environ) | |
259 | else: | |
260 | os.execve("/usr/local/bin/python", args, environ) | |
253 | 261 | |
254 | 262 | |
255 | 263 | if __name__ == "__main__": |
0 | 0 | # Contents |
1 | - [List all media in a room](#list-all-media-in-a-room) | |
1 | - [Querying media](#querying-media) | |
2 | * [List all media in a room](#list-all-media-in-a-room) | |
3 | * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user) | |
2 | 4 | - [Quarantine media](#quarantine-media) |
3 | 5 | * [Quarantining media by ID](#quarantining-media-by-id) |
4 | 6 | * [Quarantining media in a room](#quarantining-media-in-a-room) |
9 | 11 | * [Delete local media by date or size](#delete-local-media-by-date-or-size) |
10 | 12 | - [Purge Remote Media API](#purge-remote-media-api) |
11 | 13 | |
12 | # List all media in a room | |
14 | # Querying media | |
15 | ||
16 | These APIs allow extracting media information from the homeserver. | |
17 | ||
18 | ## List all media in a room | |
13 | 19 | |
14 | 20 | This API gets a list of known media in a room. |
15 | 21 | However, it only shows media from unencrypted events or rooms. |
34 | 40 | ] |
35 | 41 | } |
36 | 42 | ``` |
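For illustration, a sketch of calling this endpoint from Python. The request details are elided in this hunk; the path below assumes the room media list endpoint of the admin API, and the homeserver URL, admin token and room ID are placeholders:

```python
import requests

HOMESERVER = "https://homeserver.example"  # placeholder
ADMIN_TOKEN = "admin-access-token"         # access token of a server admin
ROOM_ID = "!qporfwt:matrix.org"            # placeholder room ID

resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v1/room/{ROOM_ID}/media",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
media = resp.json()  # the JSON object whose tail is shown above
```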
43 | ||
44 | ## List all media uploaded by a user | |
45 | ||
46 | Listing all media that has been uploaded by a local user can be achieved through | |
47 | the use of the [List media of a user](user_admin_api.rst#list-media-of-a-user) | |
48 | Admin API. | |
37 | 49 | |
38 | 50 | # Quarantine media |
39 | 51 |
225 | 225 | oidc_providers: |
226 | 226 | - idp_id: github |
227 | 227 | idp_name: Github |
228 | idp_brand: "org.matrix.github" # optional: styling hint for clients | |
228 | idp_brand: "github" # optional: styling hint for clients | |
229 | 229 | discover: false |
230 | 230 | issuer: "https://github.com/" |
231 | 231 | client_id: "your-client-id" # TO BE FILLED |
251 | 251 | oidc_providers: |
252 | 252 | - idp_id: google |
253 | 253 | idp_name: Google |
254 | idp_brand: "org.matrix.google" # optional: styling hint for clients | |
254 | idp_brand: "google" # optional: styling hint for clients | |
255 | 255 | issuer: "https://accounts.google.com/" |
256 | 256 | client_id: "your-client-id" # TO BE FILLED |
257 | 257 | client_secret: "your-client-secret" # TO BE FILLED |
298 | 298 | oidc_providers: |
299 | 299 | - idp_id: gitlab |
300 | 300 | idp_name: Gitlab |
301 | idp_brand: "org.matrix.gitlab" # optional: styling hint for clients | |
301 | idp_brand: "gitlab" # optional: styling hint for clients | |
302 | 302 | issuer: "https://gitlab.com/" |
303 | 303 | client_id: "your-client-id" # TO BE FILLED |
304 | 304 | client_secret: "your-client-secret" # TO BE FILLED |
333 | 333 | ```yaml |
334 | 334 | - idp_id: facebook |
335 | 335 | idp_name: Facebook |
336 | idp_brand: "org.matrix.facebook" # optional: styling hint for clients | |
336 | idp_brand: "facebook" # optional: styling hint for clients | |
337 | 337 | discover: false |
338 | 338 | issuer: "https://facebook.com" |
339 | 339 | client_id: "your-client-id" # TO BE FILLED |
385 | 385 | config: |
386 | 386 | subject_claim: "id" |
387 | 387 | localpart_template: "{{ user.login }}" |
388 | display_name_template: "{{ user.full_name }}" | |
388 | display_name_template: "{{ user.full_name }}" | |
389 | 389 | ``` |
390 | 390 | |
391 | 391 | ### XWiki |
400 | 400 | idp_name: "XWiki" |
401 | 401 | issuer: "https://myxwikihost/xwiki/oidc/" |
402 | 402 | client_id: "your-client-id" # TO BE FILLED |
403 | # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed | |
404 | client_secret: "dontcare" | |
403 | client_auth_method: none | |
405 | 404 | scopes: ["openid", "profile"] |
406 | 405 | user_profile_method: "userinfo_endpoint" |
407 | 406 | user_mapping_provider: |
409 | 408 | localpart_template: "{{ user.preferred_username }}" |
410 | 409 | display_name_template: "{{ user.name }}" |
411 | 410 | ``` |
411 | ||
412 | ## Apple | |
413 | ||
414 | Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account. | |
415 | ||
416 | You will need to create a new "Services ID" for SiWA, and create and download a | |
417 | private key with "SiWA" enabled. | |
418 | ||
419 | As well as the private key file, you will need: | |
420 | * Client ID: the "identifier" you gave the "Services ID" | |
421 | * Team ID: a 10-character ID associated with your developer account. | |
422 | * Key ID: the 10-character identifier for the key. | |
423 | ||
424 | https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more | |
425 | documentation on setting up SiWA. | |
426 | ||
427 | The Synapse config will look like this: | |
428 | ||
429 | ```yaml | |
430 | - idp_id: apple | |
431 | idp_name: Apple | |
432 | issuer: "https://appleid.apple.com" | |
433 | client_id: "your-client-id" # Set to the "identifier" for your "ServicesID" | |
434 | client_auth_method: "client_secret_post" | |
435 | client_secret_jwt_key: | |
436 | key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file | |
437 | jwt_header: | |
438 | alg: ES256 | |
439 | kid: "KEYIDCODE" # Set to the 10-char Key ID | |
440 | jwt_payload: | |
441 | iss: TEAMIDCODE # Set to the 10-char Team ID | |
442 | scopes: ["name", "email", "openid"] | |
443 | authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post | |
444 | user_mapping_provider: | |
445 | config: | |
446 | email_template: "{{ user.email }}" | |
447 | ``` |
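For intuition, what the `client_secret_jwt_key` machinery does at token-exchange time is roughly this: mint a short-lived ES256-signed JWT from the key and send it where a static client secret would otherwise go. A rough sketch using the `authlib` JOSE API (the library Synapse's OIDC support is built on); the key path and ID codes are the placeholders from the config above, and the exact set of claims is illustrative:

```python
import time

from authlib.jose import jwt

with open("/path/to/AuthKey_KEYIDCODE.p8") as f:
    key = f.read()

header = {"alg": "ES256", "kid": "KEYIDCODE"}  # the configured jwt_header
payload = {
    "iss": "TEAMIDCODE",                # the configured jwt_payload
    "sub": "your-client-id",            # illustrative standard claims
    "aud": "https://appleid.apple.com",
    "iat": int(time.time()),
    "exp": int(time.time()) + 3600,
}
client_secret = jwt.encode(header, payload, key).decode("ascii")
```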
2 | 2 | It is recommended to put a reverse proxy such as |
3 | 3 | [nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html), |
4 | 4 | [Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html), |
5 | [Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or | |
6 | [HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage | |
5 | [Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy), | |
6 | [HAProxy](https://www.haproxy.org/) or | |
7 | [relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage | |
7 | 8 | of doing so is that it means that you can expose the default https port |
8 | 9 | (443) to Matrix clients without needing to run Synapse with root |
9 | 10 | privileges. |
161 | 162 | server matrix 127.0.0.1:8008 |
162 | 163 | ``` |
163 | 164 | |
165 | ### Relayd | |
166 | ||
167 | ``` | |
168 | table <webserver> { 127.0.0.1 } | |
169 | table <matrixserver> { 127.0.0.1 } | |
170 | ||
171 | http protocol "https" { | |
172 | tls { no tlsv1.0, ciphers "HIGH" } | |
173 | tls keypair "example.com" | |
174 | match header set "X-Forwarded-For" value "$REMOTE_ADDR" | |
175 | match header set "X-Forwarded-Proto" value "https" | |
176 | ||
177 | # set CORS header for .well-known/matrix/server, .well-known/matrix/client | |
178 | # httpd does not support setting headers, so do it here | |
179 | match request path "/.well-known/matrix/*" tag "matrix-cors" | |
180 | match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*" | |
181 | ||
182 | pass quick path "/_matrix/*" forward to <matrixserver> | |
183 | pass quick path "/_synapse/client/*" forward to <matrixserver> | |
184 | ||
185 | # pass on non-matrix traffic to webserver | |
186 | pass forward to <webserver> | |
187 | } | |
188 | ||
189 | relay "https_traffic" { | |
190 | listen on egress port 443 tls | |
191 | protocol "https" | |
192 | forward to <matrixserver> port 8008 check tcp | |
193 | forward to <webserver> port 8080 check tcp | |
194 | } | |
195 | ||
196 | http protocol "matrix" { | |
197 | tls { no tlsv1.0, ciphers "HIGH" } | |
198 | tls keypair "example.com" | |
199 | block | |
200 | pass quick path "/_matrix/*" forward to <matrixserver> | |
201 | pass quick path "/_synapse/client/*" forward to <matrixserver> | |
202 | } | |
203 | ||
204 | relay "matrix_federation" { | |
205 | listen on egress port 8448 tls | |
206 | protocol "matrix" | |
207 | forward to <matrixserver> port 8008 check tcp | |
208 | } | |
209 | ``` | |
210 | ||
164 | 211 | ## Homeserver Configuration |
165 | 212 | |
166 | 213 | You will also want to set `bind_addresses: ['127.0.0.1']` and |
88 | 88 | # Whether to require authentication to retrieve profile data (avatars, |
89 | 89 | # display names) of other users through the client API. Defaults to |
90 | 90 | # 'false'. Note that profile data is also available via the federation |
91 | # API, so this setting is of limited value if federation is enabled on | |
92 | # the server. | |
91 | # API, unless allow_profile_lookup_over_federation is set to false. | |
93 | 92 | # |
94 | 93 | #require_auth_for_profile_requests: true |
95 | 94 | |
1779 | 1778 | # |
1780 | 1779 | # client_id: Required. oauth2 client id to use. |
1781 | 1780 | # |
1782 | # client_secret: Required. oauth2 client secret to use. | |
1781 | # client_secret: oauth2 client secret to use. May be omitted if | |
1782 | # client_secret_jwt_key is given, or if client_auth_method is 'none'. | |
1783 | # | |
1784 | # client_secret_jwt_key: Alternative to client_secret: details of a key used | |
1785 | # to create a JSON Web Token to be used as an OAuth2 client secret. If | |
1786 | # given, must be a dictionary with the following properties: | |
1787 | # | |
1788 | # key: a pem-encoded signing key. Must be a suitable key for the | |
1789 | # algorithm specified. Required unless 'key_file' is given. | |
1790 | # | |
1791 | # key_file: the path to file containing a pem-encoded signing key file. | |
1792 | # Required unless 'key' is given. | |
1793 | # | |
1794 | # jwt_header: a dictionary giving properties to include in the JWT | |
1795 | # header. Must include the key 'alg', giving the algorithm used to | |
1796 | # sign the JWT, such as "ES256", using the JWA identifiers in | |
1797 | # RFC7518. | |
1798 | # | |
1799 | # jwt_payload: an optional dictionary giving properties to include in | |
1800 | # the JWT payload. Normally this should include an 'iss' key. | |
1783 | 1801 | # |
1784 | 1802 | # client_auth_method: auth method to use when exchanging the token. Valid |
1785 | 1803 | # values are 'client_secret_basic' (default), 'client_secret_post' and |
1900 | 1918 | # |
1901 | 1919 | #- idp_id: github |
1902 | 1920 | # idp_name: Github |
1903 | # idp_brand: org.matrix.github | |
1921 | # idp_brand: github | |
1904 | 1922 | # discover: false |
1905 | 1923 | # issuer: "https://github.com/" |
1906 | 1924 | # client_id: "your-client-id" # TO BE FILLED |
2626 | 2644 | |
2627 | 2645 | |
2628 | 2646 | |
2629 | # Local statistics collection. Used in populating the room directory. | |
2630 | # | |
2631 | # 'bucket_size' controls how large each statistics timeslice is. It can | |
2632 | # be defined in a human readable short form -- e.g. "1d", "1y". | |
2633 | # | |
2634 | # 'retention' controls how long historical statistics will be kept for. | |
2635 | # It can be defined in a human readable short form -- e.g. "1d", "1y". | |
2636 | # | |
2637 | # | |
2638 | #stats: | |
2639 | # enabled: true | |
2640 | # bucket_size: 1d | |
2641 | # retention: 1y | |
2647 | # Settings for local room and user statistics collection. See | |
2648 | # docs/room_and_user_statistics.md. | |
2649 | # | |
2650 | stats: | |
2651 | # Uncomment the following to disable room and user statistics. Note that doing | |
2652 | # so may cause certain features (such as the room directory) not to work | |
2653 | # correctly. | |
2654 | # | |
2655 | #enabled: false | |
2656 | ||
2657 | # The size of each timeslice in the room_stats_historical and | |
2658 | # user_stats_historical tables, as a time period. Defaults to "1d". | |
2659 | # | |
2660 | #bucket_size: 1h | |
2642 | 2661 | |
2643 | 2662 | |
2644 | 2663 | # Server Notices room configuration |
13 | 13 | * An instance of `synapse.module_api.ModuleApi`. |
14 | 14 | |
15 | 15 | It then implements methods which return a boolean to alter behavior in Synapse. |
16 | All the methods must be defined. | |
16 | 17 | |
17 | 18 | There's a generic method for checking every event (`check_event_for_spam`), as |
18 | 19 | well as some specific methods: |
23 | 24 | * `user_may_publish_room` |
24 | 25 | * `check_username_for_spam` |
25 | 26 | * `check_registration_for_spam` |
27 | * `check_media_file_for_spam` | |
26 | 28 | |
27 | 29 | The details of each of these methods (as well as their inputs and outputs) |
28 | 30 | are documented in the `synapse.events.spamcheck.SpamChecker` class. |
29 | 31 | |
30 | 32 | The `ModuleApi` class provides a way for the custom spam checker class to |
31 | 33 | call back into the homeserver internals. |
34 | ||
35 | Additionally, a `parse_config` method is mandatory and receives the plugin config | |
36 | dictionary. After parsing, it must return an object which will be | |
37 | passed to `__init__` later. | |
32 | 38 | |
33 | 39 | ### Example |
34 | 40 | |
40 | 46 | self.config = config |
41 | 47 | self.api = api |
42 | 48 | |
49 | @staticmethod | |
50 | def parse_config(config): | |
51 | return config | |
52 | ||
43 | 53 | async def check_event_for_spam(self, foo): |
44 | 54 | return False # allow all events |
45 | 55 | |
58 | 68 | async def check_username_for_spam(self, user_profile): |
59 | 69 | return False # allow all usernames |
60 | 70 | |
61 | async def check_registration_for_spam(self, email_threepid, username, request_info): | |
71 | async def check_registration_for_spam( | |
72 | self, | |
73 | email_threepid, | |
74 | username, | |
75 | request_info, | |
76 | auth_provider_id, | |
77 | ): | |
62 | 78 | return RegistrationBehaviour.ALLOW # allow all registrations |
63 | 79 | |
64 | 80 | async def check_media_file_for_spam(self, file_wrapper, file_info): |
68 | 68 | synapse/util/async_helpers.py, |
69 | 69 | synapse/util/caches, |
70 | 70 | synapse/util/metrics.py, |
71 | synapse/util/macaroons.py, | |
71 | 72 | synapse/util/stringutils.py, |
72 | 73 | tests/replication, |
73 | 74 | tests/test_utils, |
113 | 114 | ignore_missing_imports = True |
114 | 115 | |
115 | 116 | [mypy-saml2.*] |
116 | ignore_missing_imports = True | |
117 | ||
118 | [mypy-unpaddedbase64] | |
119 | 117 | ignore_missing_imports = True |
120 | 118 | |
121 | 119 | [mypy-canonicaljson] |
1 | 1 | # Find linting errors in Synapse's default config file. |
2 | 2 | # Exits with 0 if there are no problems, or another code otherwise. |
3 | 3 | |
4 | # cd to the root of the repository | |
5 | cd `dirname $0`/.. | |
6 | ||
7 | # Restore backup of sample config upon script exit | |
8 | trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT | |
9 | ||
4 | 10 | # Fix non-lowercase true/false values |
5 | 11 | sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml |
6 | rm docs/sample_config.yaml.bak | |
7 | 12 | |
8 | 13 | # Check if anything changed |
9 | git diff --exit-code docs/sample_config.yaml | |
14 | diff docs/sample_config.yaml docs/sample_config.yaml.bak |
16 | 16 | """ |
17 | 17 | from typing import Any, List, Optional, Type, Union |
18 | 18 | |
19 | class RedisProtocol: | |
19 | from twisted.internet import protocol | |
20 | ||
21 | class RedisProtocol(protocol.Protocol): | |
20 | 22 | def publish(self, channel: str, message: bytes): ... |
21 | 23 | async def ping(self) -> None: ... |
22 | 24 | async def set( |
51 | 53 | |
52 | 54 | class ConnectionHandler: ... |
53 | 55 | |
54 | class RedisFactory: | |
56 | class RedisFactory(protocol.ReconnectingClientFactory): | |
55 | 57 | continueTrying: bool |
56 | 58 | handler: RedisProtocol |
57 | 59 | pool: List[RedisProtocol] |
47 | 47 | except ImportError: |
48 | 48 | pass |
49 | 49 | |
50 | __version__ = "1.29.0" | |
50 | __version__ = "1.30.0" | |
51 | 51 | |
52 | 52 | if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): |
53 | 53 | # We import here so that we don't have to install a bunch of deps when |
38 | 38 | from synapse.storage.databases.main.registration import TokenLookupResult |
39 | 39 | from synapse.types import StateMap, UserID |
40 | 40 | from synapse.util.caches.lrucache import LruCache |
41 | from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry | |
41 | 42 | from synapse.util.metrics import Measure |
42 | 43 | |
43 | 44 | logger = logging.getLogger(__name__) |
162 | 163 | |
163 | 164 | async def get_user_by_req( |
164 | 165 | self, |
165 | request: Request, | |
166 | request: SynapseRequest, | |
166 | 167 | allow_guest: bool = False, |
167 | 168 | rights: str = "access", |
168 | 169 | allow_expired: bool = False, |
407 | 408 | raise _InvalidMacaroonException() |
408 | 409 | |
409 | 410 | try: |
410 | user_id = self.get_user_id_from_macaroon(macaroon) | |
411 | user_id = get_value_from_macaroon(macaroon, "user_id") | |
411 | 412 | |
412 | 413 | guest = False |
413 | 414 | for caveat in macaroon.caveats: |
415 | 416 | guest = True |
416 | 417 | |
417 | 418 | self.validate_macaroon(macaroon, rights, user_id=user_id) |
418 | except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): | |
419 | except ( | |
420 | pymacaroons.exceptions.MacaroonException, | |
421 | KeyError, | |
422 | TypeError, | |
423 | ValueError, | |
424 | ): | |
419 | 425 | raise InvalidClientTokenError("Invalid macaroon passed.") |
420 | 426 | |
421 | 427 | if rights == "access": |
422 | 428 | self.token_cache[token] = (user_id, guest) |
423 | 429 | |
424 | 430 | return user_id, guest |
425 | ||
426 | def get_user_id_from_macaroon(self, macaroon): | |
427 | """Retrieve the user_id given by the caveats on the macaroon. | |
428 | ||
429 | Does *not* validate the macaroon. | |
430 | ||
431 | Args: | |
432 | macaroon (pymacaroons.Macaroon): The macaroon to validate | |
433 | ||
434 | Returns: | |
435 | (str) user id | |
436 | ||
437 | Raises: | |
438 | InvalidClientCredentialsError if there is no user_id caveat in the | |
439 | macaroon | |
440 | """ | |
441 | user_prefix = "user_id = " | |
442 | for caveat in macaroon.caveats: | |
443 | if caveat.caveat_id.startswith(user_prefix): | |
444 | return caveat.caveat_id[len(user_prefix) :] | |
445 | raise InvalidClientTokenError("No user caveat in macaroon") | |
446 | 431 | |
447 | 432 | def validate_macaroon(self, macaroon, type_string, user_id): |
448 | 433 | """ |
464 | 449 | v.satisfy_exact("type = " + type_string) |
465 | 450 | v.satisfy_exact("user_id = %s" % user_id) |
466 | 451 | v.satisfy_exact("guest = true") |
467 | v.satisfy_general(self._verify_expiry) | |
452 | satisfy_expiry(v, self.clock.time_msec) | |
468 | 453 | |
469 | 454 | # access_tokens include a nonce for uniqueness: any value is acceptable |
470 | 455 | v.satisfy_general(lambda c: c.startswith("nonce = ")) |
471 | 456 | |
472 | 457 | v.verify(macaroon, self._macaroon_secret_key) |
473 | ||
474 | def _verify_expiry(self, caveat): | |
475 | prefix = "time < " | |
476 | if not caveat.startswith(prefix): | |
477 | return False | |
478 | expiry = int(caveat[len(prefix) :]) | |
479 | now = self.hs.get_clock().time_msec() | |
480 | return now < expiry | |
481 | 458 | |
482 | 459 | def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService: |
483 | 460 | token = self.get_access_token_from_request(request) |
89 | 89 | self.clock = hs.get_clock() |
90 | 90 | |
91 | 91 | self.protocol_meta_cache = ResponseCache( |
92 | hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS | |
92 | hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS | |
93 | 93 | ) # type: ResponseCache[Tuple[str, str]] |
94 | 94 | |
95 | 95 | async def query_user(self, service, user_id): |
211 | 211 | |
212 | 212 | @classmethod |
213 | 213 | def read_file(cls, file_path, config_name): |
214 | cls.check_file(file_path, config_name) | |
215 | with open(file_path) as file_stream: | |
216 | return file_stream.read() | |
214 | """Deprecated: call read_file directly""" | |
215 | return read_file(file_path, (config_name,)) | |
217 | 216 | |
218 | 217 | def read_template(self, filename: str) -> jinja2.Template: |
219 | 218 | """Load a template file from disk. |
893 | 892 | return self._get_instance(key) |
894 | 893 | |
895 | 894 | |
896 | __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] | |
895 | def read_file(file_path: Any, config_path: Iterable[str]) -> str: | |
896 | """Check the given file exists, and read it into a string | |
897 | ||
898 | If it does not, emit an error indicating the problem | |
899 | ||
900 | Args: | |
901 | file_path: the file to be read | |
902 | config_path: where in the configuration file_path came from, so that a useful | |
903 | error can be emitted if it does not exist. | |
904 | Returns: | |
905 | content of the file. | |
906 | Raises: | |
907 | ConfigError if there is a problem reading the file. | |
908 | """ | |
909 | if not isinstance(file_path, str): | |
910 | raise ConfigError("%r is not a string", config_path) | |
911 | ||
912 | try: | |
913 | os.stat(file_path) | |
914 | with open(file_path) as file_stream: | |
915 | return file_stream.read() | |
916 | except OSError as e: | |
917 | raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e | |
918 | ||
919 | ||
920 | __all__ = [ | |
921 | "Config", | |
922 | "RootConfig", | |
923 | "ShardedWorkerHandlingConfig", | |
924 | "RoutableShardedWorkerHandlingConfig", | |
925 | "read_file", | |
926 | ] |
151 | 151 | |
152 | 152 | class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): |
153 | 153 | def get_instance(self, key: str) -> str: ... |
154 | ||
155 | def read_file(file_path: Any, config_path: Iterable[str]) -> str: ... |
20 | 20 | from string import Template |
21 | 21 | |
22 | 22 | import yaml |
23 | from zope.interface import implementer | |
23 | 24 | |
24 | 25 | from twisted.logger import ( |
26 | ILogObserver, | |
25 | 27 | LogBeginner, |
26 | 28 | STDLibLogObserver, |
27 | 29 | eventAsText, |
226 | 228 | |
227 | 229 | threadlocal = threading.local() |
228 | 230 | |
229 | def _log(event): | |
231 | @implementer(ILogObserver) | |
232 | def _log(event: dict) -> None: | |
230 | 233 | if "log_text" in event: |
231 | 234 | if event["log_text"].startswith("DNSDatagramProtocol starting on "): |
232 | 235 | return |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | 16 | from collections import Counter |
17 | from typing import Iterable, Optional, Tuple, Type | |
17 | from typing import Iterable, Mapping, Optional, Tuple, Type | |
18 | 18 | |
19 | 19 | import attr |
20 | 20 | |
24 | 24 | from synapse.util.module_loader import load_module |
25 | 25 | from synapse.util.stringutils import parse_and_validate_mxc_uri |
26 | 26 | |
27 | from ._base import Config, ConfigError | |
27 | from ._base import Config, ConfigError, read_file | |
28 | 28 | |
29 | 29 | DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" |
30 | 30 | |
96 | 96 | # |
97 | 97 | # client_id: Required. oauth2 client id to use. |
98 | 98 | # |
99 | # client_secret: Required. oauth2 client secret to use. | |
99 | # client_secret: oauth2 client secret to use. May be omitted if | |
100 | # client_secret_jwt_key is given, or if client_auth_method is 'none'. | |
101 | # | |
102 | # client_secret_jwt_key: Alternative to client_secret: details of a key used | |
103 | # to create a JSON Web Token to be used as an OAuth2 client secret. If | |
104 | # given, must be a dictionary with the following properties: | |
105 | # | |
106 | # key: a pem-encoded signing key. Must be a suitable key for the | |
107 | # algorithm specified. Required unless 'key_file' is given. | |
108 | # | |
109 | # key_file: the path to file containing a pem-encoded signing key file. | |
110 | # Required unless 'key' is given. | |
111 | # | |
112 | # jwt_header: a dictionary giving properties to include in the JWT | |
113 | # header. Must include the key 'alg', giving the algorithm used to | |
114 | # sign the JWT, such as "ES256", using the JWA identifiers in | |
115 | # RFC7518. | |
116 | # | |
117 | # jwt_payload: an optional dictionary giving properties to include in | |
118 | # the JWT payload. Normally this should include an 'iss' key. | |
100 | 119 | # |
101 | 120 | # client_auth_method: auth method to use when exchanging the token. Valid |
102 | 121 | # values are 'client_secret_basic' (default), 'client_secret_post' and |
217 | 236 | # |
218 | 237 | #- idp_id: github |
219 | 238 | # idp_name: Github |
220 | # idp_brand: org.matrix.github | |
239 | # idp_brand: github | |
221 | 240 | # discover: false |
222 | 241 | # issuer: "https://github.com/" |
223 | 242 | # client_id: "your-client-id" # TO BE FILLED |
239 | 258 | # jsonschema definition of the configuration settings for an oidc identity provider |
240 | 259 | OIDC_PROVIDER_CONFIG_SCHEMA = { |
241 | 260 | "type": "object", |
242 | "required": ["issuer", "client_id", "client_secret"], | |
261 | "required": ["issuer", "client_id"], | |
243 | 262 | "properties": { |
244 | 263 | "idp_id": { |
245 | 264 | "type": "string", |
252 | 271 | "idp_icon": {"type": "string"}, |
253 | 272 | "idp_brand": { |
254 | 273 | "type": "string", |
255 | # MSC2758-style namespaced identifier | |
274 | "minLength": 1, | |
275 | "maxLength": 255, | |
276 | "pattern": "^[a-z][a-z0-9_.-]*$", | |
277 | }, | |
278 | "idp_unstable_brand": { | |
279 | "type": "string", | |
256 | 280 | "minLength": 1, |
257 | 281 | "maxLength": 255, |
258 | 282 | "pattern": "^[a-z][a-z0-9_.-]*$", |
261 | 285 | "issuer": {"type": "string"}, |
262 | 286 | "client_id": {"type": "string"}, |
263 | 287 | "client_secret": {"type": "string"}, |
288 | "client_secret_jwt_key": { | |
289 | "type": "object", | |
290 | "required": ["jwt_header"], | |
291 | "oneOf": [ | |
292 | {"required": ["key"]}, | |
293 | {"required": ["key_file"]}, | |
294 | ], | |
295 | "properties": { | |
296 | "key": {"type": "string"}, | |
297 | "key_file": {"type": "string"}, | |
298 | "jwt_header": { | |
299 | "type": "object", | |
300 | "required": ["alg"], | |
301 | "properties": { | |
302 | "alg": {"type": "string"}, | |
303 | }, | |
304 | "additionalProperties": {"type": "string"}, | |
305 | }, | |
306 | "jwt_payload": { | |
307 | "type": "object", | |
308 | "additionalProperties": {"type": "string"}, | |
309 | }, | |
310 | }, | |
311 | }, | |
264 | 312 | "client_auth_method": { |
265 | 313 | "type": "string", |
266 | 314 | # the following list is the same as the keys of |
403 | 451 | "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) |
404 | 452 | ) from e |
405 | 453 | |
454 | client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") | |
455 | client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey] | |
456 | if client_secret_jwt_key_config is not None: | |
457 | keyfile = client_secret_jwt_key_config.get("key_file") | |
458 | if keyfile: | |
459 | key = read_file(keyfile, config_path + ("client_secret_jwt_key",)) | |
460 | else: | |
461 | key = client_secret_jwt_key_config["key"] | |
462 | client_secret_jwt_key = OidcProviderClientSecretJwtKey( | |
463 | key=key, | |
464 | jwt_header=client_secret_jwt_key_config["jwt_header"], | |
465 | jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}), | |
466 | ) | |
467 | ||
406 | 468 | return OidcProviderConfig( |
407 | 469 | idp_id=idp_id, |
408 | 470 | idp_name=oidc_config.get("idp_name", "OIDC"), |
409 | 471 | idp_icon=idp_icon, |
410 | 472 | idp_brand=oidc_config.get("idp_brand"), |
473 | unstable_idp_brand=oidc_config.get("unstable_idp_brand"), | |
411 | 474 | discover=oidc_config.get("discover", True), |
412 | 475 | issuer=oidc_config["issuer"], |
413 | 476 | client_id=oidc_config["client_id"], |
414 | client_secret=oidc_config["client_secret"], | |
477 | client_secret=oidc_config.get("client_secret"), | |
478 | client_secret_jwt_key=client_secret_jwt_key, | |
415 | 479 | client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), |
416 | 480 | scopes=oidc_config.get("scopes", ["openid"]), |
417 | 481 | authorization_endpoint=oidc_config.get("authorization_endpoint"), |
427 | 491 | |
428 | 492 | |
429 | 493 | @attr.s(slots=True, frozen=True) |
494 | class OidcProviderClientSecretJwtKey: | |
495 | # a pem-encoded signing key | |
496 | key = attr.ib(type=str) | |
497 | ||
498 | # properties to include in the JWT header | |
499 | jwt_header = attr.ib(type=Mapping[str, str]) | |
500 | ||
501 | # properties to include in the JWT payload. | |
502 | jwt_payload = attr.ib(type=Mapping[str, str]) | |
503 | ||
504 | ||
505 | @attr.s(slots=True, frozen=True) | |
430 | 506 | class OidcProviderConfig: |
431 | 507 | # a unique identifier for this identity provider. Used in the 'user_external_ids' |
432 | 508 | # table, as well as the query/path parameter used in the login protocol. |
441 | 517 | # Optional brand identifier for this IdP. |
442 | 518 | idp_brand = attr.ib(type=Optional[str]) |
443 | 519 | |
520 | # Optional brand identifier for the unstable API (see MSC2858). | |
521 | unstable_idp_brand = attr.ib(type=Optional[str]) | |
522 | ||
444 | 523 | # whether the OIDC discovery mechanism is used to discover endpoints |
445 | 524 | discover = attr.ib(type=bool) |
446 | 525 | |
451 | 530 | # oauth2 client id to use |
452 | 531 | client_id = attr.ib(type=str) |
453 | 532 | |
454 | # oauth2 client secret to use | |
455 | client_secret = attr.ib(type=str) | |
533 | # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate | |
534 | # a secret. | |
535 | client_secret = attr.ib(type=Optional[str]) | |
536 | ||
537 | # key to use to construct a JWT to use as a client secret. May be `None` if | |
538 | # `client_secret` is set. | |
539 | client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey]) | |
456 | 540 | |
457 | 541 | # auth method to use when exchanging the token. |
458 | 542 | # Valid values are 'client_secret_basic', 'client_secret_post' and |
840 | 840 | # Whether to require authentication to retrieve profile data (avatars, |
841 | 841 | # display names) of other users through the client API. Defaults to |
842 | 842 | # 'false'. Note that profile data is also available via the federation |
843 | # API, so this setting is of limited value if federation is enabled on | |
844 | # the server. | |
843 | # API, unless allow_profile_lookup_over_federation is set to false. | |
845 | 844 | # |
846 | 845 | #require_auth_for_profile_requests: true |
847 | 846 |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import sys | |
15 | import logging | |
16 | 16 | |
17 | 17 | from ._base import Config |
18 | ||
19 | ROOM_STATS_DISABLED_WARN = """\ | |
20 | WARNING: room/user statistics have been disabled via the stats.enabled | |
21 | configuration setting. This means that certain features (such as the room | |
22 | directory) will not operate correctly. Future versions of Synapse may ignore | |
23 | this setting. | |
24 | ||
25 | To fix this warning, remove the stats.enabled setting from your configuration | |
26 | file. | |
27 | --------------------------------------------------------------------------------""" | |
28 | ||
29 | logger = logging.getLogger(__name__) | |
18 | 30 | |
19 | 31 | |
20 | 32 | class StatsConfig(Config): |
27 | 39 | def read_config(self, config, **kwargs): |
28 | 40 | self.stats_enabled = True |
29 | 41 | self.stats_bucket_size = 86400 * 1000 |
30 | self.stats_retention = sys.maxsize | |
31 | 42 | stats_config = config.get("stats", None) |
32 | 43 | if stats_config: |
33 | 44 | self.stats_enabled = stats_config.get("enabled", self.stats_enabled) |
34 | 45 | self.stats_bucket_size = self.parse_duration( |
35 | 46 | stats_config.get("bucket_size", "1d") |
36 | 47 | ) |
37 | self.stats_retention = self.parse_duration( | |
38 | stats_config.get("retention", "%ds" % (sys.maxsize,)) | |
39 | ) | |
48 | if not self.stats_enabled: | |
49 | logger.warning(ROOM_STATS_DISABLED_WARN) | |
40 | 50 | |
41 | 51 | def generate_config_section(self, config_dir_path, server_name, **kwargs): |
42 | 52 | return """ |
43 | # Local statistics collection. Used in populating the room directory. | |
53 | # Settings for local room and user statistics collection. See | |
54 | # docs/room_and_user_statistics.md. | |
44 | 55 | # |
45 | # 'bucket_size' controls how large each statistics timeslice is. It can | |
46 | # be defined in a human readable short form -- e.g. "1d", "1y". | |
47 | # | |
48 | # 'retention' controls how long historical statistics will be kept for. | |
49 | # It can be defined in a human readable short form -- e.g. "1d", "1y". | |
50 | # | |
51 | # | |
52 | #stats: | |
53 | # enabled: true | |
54 | # bucket_size: 1d | |
55 | # retention: 1y | |
56 | stats: | |
57 | # Uncomment the following to disable room and user statistics. Note that doing | |
58 | # so may cause certain features (such as the room directory) not to work | |
59 | # correctly. | |
60 | # | |
61 | #enabled: false | |
62 | ||
63 | # The size of each timeslice in the room_stats_historical and | |
64 | # user_stats_historical tables, as a time period. Defaults to "1d". | |
65 | # | |
66 | #bucket_size: 1h | |
56 | 67 | """ |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | 16 | import inspect |
17 | import logging | |
17 | 18 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union |
18 | 19 | |
19 | 20 | from synapse.rest.media.v1._base import FileInfo |
25 | 26 | if TYPE_CHECKING: |
26 | 27 | import synapse.events |
27 | 28 | import synapse.server |
29 | ||
30 | logger = logging.getLogger(__name__) | |
28 | 31 | |
29 | 32 | |
30 | 33 | class SpamChecker: |
189 | 192 | email_threepid: Optional[dict], |
190 | 193 | username: Optional[str], |
191 | 194 | request_info: Collection[Tuple[str, str]], |
195 | auth_provider_id: Optional[str] = None, | |
192 | 196 | ) -> RegistrationBehaviour: |
193 | 197 | """Checks if we should allow the given registration request. |
194 | 198 | |
197 | 201 | username: The request user name, if any |
198 | 202 | request_info: List of tuples of user agent and IP that |
199 | 203 | were used during the registration process. |
204 | auth_provider_id: The SSO IdP the user used, e.g. "oidc", "saml", | |
205 | "cas", if any. Note this does not include users registered | |
206 | via a password provider. | |
200 | 207 | |
201 | 208 | Returns: |
202 | 209 | Enum for how the request should be handled |
207 | 214 | # spam checker |
208 | 215 | checker = getattr(spam_checker, "check_registration_for_spam", None) |
209 | 216 | if checker: |
210 | behaviour = await maybe_awaitable( | |
211 | checker(email_threepid, username, request_info) | |
212 | ) | |
217 | # Provide auth_provider_id if the function supports it | |
218 | checker_args = inspect.signature(checker) | |
219 | if len(checker_args.parameters) == 4: | |
220 | d = checker( | |
221 | email_threepid, | |
222 | username, | |
223 | request_info, | |
224 | auth_provider_id, | |
225 | ) | |
226 | elif len(checker_args.parameters) == 3: | |
227 | d = checker(email_threepid, username, request_info) | |
228 | else: | |
229 | logger.error( | |
230 | "Invalid signature for %s.check_registration_for_spam. Denying registration", | |
231 | spam_checker.__module__, | |
232 | ) | |
233 | return RegistrationBehaviour.DENY | |
234 | ||
235 | behaviour = await maybe_awaitable(d) | |
213 | 236 | assert isinstance(behaviour, RegistrationBehaviour) |
214 | 237 | if behaviour != RegistrationBehaviour.ALLOW: |
215 | 238 | return behaviour |
21 | 21 | Awaitable, |
22 | 22 | Callable, |
23 | 23 | Dict, |
24 | Iterable, | |
24 | 25 | List, |
25 | 26 | Optional, |
26 | 27 | Tuple, |
89 | 90 | "Time taken to process an event", |
90 | 91 | ) |
91 | 92 | |
92 | ||
93 | last_pdu_age_metric = Gauge( | |
94 | "synapse_federation_last_received_pdu_age", | |
95 | "The age (in seconds) of the last PDU successfully received from the given domain", | |
93 | last_pdu_ts_metric = Gauge( | |
94 | "synapse_federation_last_received_pdu_time", | |
95 | "The timestamp of the last PDU which was successfully received from the given domain", | |
96 | 96 | labelnames=("server_name",), |
97 | 97 | ) |
98 | 98 | |
99 | 99 | |
100 | 100 | class FederationServer(FederationBase): |
101 | def __init__(self, hs): | |
101 | def __init__(self, hs: "HomeServer"): | |
102 | 102 | super().__init__(hs) |
103 | 103 | |
104 | 104 | self.auth = hs.get_auth() |
111 | 111 | # with FederationHandlerRegistry. |
112 | 112 | hs.get_directory_handler() |
113 | 113 | |
114 | self._federation_ratelimiter = hs.get_federation_ratelimiter() | |
115 | ||
116 | 114 | self._server_linearizer = Linearizer("fed_server") |
117 | self._transaction_linearizer = Linearizer("fed_txn_handler") | |
115 | ||
116 | # origins that we are currently processing a transaction from. | |
117 | # a dict from origin to txn id. | |
118 | self._active_transactions = {} # type: Dict[str, str] | |
118 | 119 | |
119 | 120 | # We cache results for transaction with the same ID |
120 | 121 | self._transaction_resp_cache = ResponseCache( |
121 | hs, "fed_txn_handler", timeout_ms=30000 | |
122 | hs.get_clock(), "fed_txn_handler", timeout_ms=30000 | |
122 | 123 | ) # type: ResponseCache[Tuple[str, str]] |
123 | 124 | |
124 | 125 | self.transaction_actions = TransactionActions(self.store) |
128 | 129 | # We cache responses to state queries, as they take a while and often |
129 | 130 | # come in waves. |
130 | 131 | self._state_resp_cache = ResponseCache( |
131 | hs, "state_resp", timeout_ms=30000 | |
132 | hs.get_clock(), "state_resp", timeout_ms=30000 | |
132 | 133 | ) # type: ResponseCache[Tuple[str, str]] |
133 | 134 | self._state_ids_resp_cache = ResponseCache( |
134 | hs, "state_ids_resp", timeout_ms=30000 | |
135 | hs.get_clock(), "state_ids_resp", timeout_ms=30000 | |
135 | 136 | ) # type: ResponseCache[Tuple[str, str]] |
136 | 137 | |
137 | 138 | self._federation_metrics_domains = ( |
167 | 168 | raise Exception("Transaction missing transaction_id") |
168 | 169 | |
169 | 170 | logger.debug("[%s] Got transaction", transaction_id) |
171 | ||
172 | # Reject malformed transactions early: reject if too many PDUs/EDUs | |
173 | if len(transaction.pdus) > 50 or ( # type: ignore | |
174 | hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore | |
175 | ): | |
176 | logger.info("Transaction PDU or EDU count too large. Returning 400") | |
177 | return 400, {} | |
178 | ||
179 | # we only process one transaction from each origin at a time. We need to do | |
180 | # this check here, rather than in _on_incoming_transaction_inner so that we | |
181 | # don't cache the rejection in _transaction_resp_cache (so that if the txn | |
182 | # arrives again later, we can process it). | |
183 | current_transaction = self._active_transactions.get(origin) | |
184 | if current_transaction and current_transaction != transaction_id: | |
185 | logger.warning( | |
186 | "Received another txn %s from %s while still processing %s", | |
187 | transaction_id, | |
188 | origin, | |
189 | current_transaction, | |
190 | ) | |
191 | return 429, { | |
192 | "errcode": Codes.UNKNOWN, | |
193 | "error": "Too many concurrent transactions", | |
194 | } | |
195 | ||
196 | # CRITICAL SECTION: we must now not await until we populate _active_transactions | |
197 | # in _on_incoming_transaction_inner. | |
170 | 198 | |
171 | 199 | # We wrap in a ResponseCache so that we de-duplicate retried |
172 | 200 | # transactions. |
181 | 209 | async def _on_incoming_transaction_inner( |
182 | 210 | self, origin: str, transaction: Transaction, request_time: int |
183 | 211 | ) -> Tuple[int, Dict[str, Any]]: |
184 | # Use a linearizer to ensure that transactions from a remote are | |
185 | # processed in order. | |
186 | with await self._transaction_linearizer.queue(origin): | |
187 | # We rate limit here *after* we've queued up the incoming requests, | |
188 | # so that we don't fill up the ratelimiter with blocked requests. | |
189 | # | |
190 | # This is important as the ratelimiter allows N concurrent requests | |
191 | # at a time, and only starts ratelimiting if there are more requests | |
192 | # than that being processed at a time. If we queued up requests in | |
193 | # the linearizer/response cache *after* the ratelimiting then those | |
194 | # queued up requests would count as part of the allowed limit of N | |
195 | # concurrent requests. | |
196 | with self._federation_ratelimiter.ratelimit(origin) as d: | |
197 | await d | |
198 | ||
199 | result = await self._handle_incoming_transaction( | |
200 | origin, transaction, request_time | |
201 | ) | |
202 | ||
203 | return result | |
212 | # CRITICAL SECTION: the first thing we must do (before awaiting) is | |
213 | # add an entry to _active_transactions. | |
214 | assert origin not in self._active_transactions | |
215 | self._active_transactions[origin] = transaction.transaction_id # type: ignore | |
216 | ||
217 | try: | |
218 | result = await self._handle_incoming_transaction( | |
219 | origin, transaction, request_time | |
220 | ) | |
221 | return result | |
222 | finally: | |
223 | del self._active_transactions[origin] | |
204 | 224 | |
205 | 225 | async def _handle_incoming_transaction( |
206 | 226 | self, origin: str, transaction: Transaction, request_time: int |
225 | 245 | return response |
226 | 246 | |
227 | 247 | logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore |
228 | ||
229 | # Reject if PDU count > 50 or EDU count > 100 | |
230 | if len(transaction.pdus) > 50 or ( # type: ignore | |
231 | hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore | |
232 | ): | |
233 | ||
234 | logger.info("Transaction PDU or EDU count too large. Returning 400") | |
235 | ||
236 | response = {} | |
237 | await self.transaction_actions.set_response( | |
238 | origin, transaction, 400, response | |
239 | ) | |
240 | return 400, response | |
241 | 248 | |
242 | 249 | # We process PDUs and EDUs in parallel. This is important as we don't |
243 | 250 | # want to block things like to device messages from reaching clients |
334 | 341 | # impose a limit to avoid going too crazy with ram/cpu. |
335 | 342 | |
336 | 343 | async def process_pdus_for_room(room_id: str): |
337 | logger.debug("Processing PDUs for %s", room_id) | |
338 | try: | |
339 | await self.check_server_matches_acl(origin_host, room_id) | |
340 | except AuthError as e: | |
341 | logger.warning("Ignoring PDUs for room %s from banned server", room_id) | |
344 | with nested_logging_context(room_id): | |
345 | logger.debug("Processing PDUs for %s", room_id) | |
346 | ||
347 | try: | |
348 | await self.check_server_matches_acl(origin_host, room_id) | |
349 | except AuthError as e: | |
350 | logger.warning( | |
351 | "Ignoring PDUs for room %s from banned server", room_id | |
352 | ) | |
353 | for pdu in pdus_by_room[room_id]: | |
354 | event_id = pdu.event_id | |
355 | pdu_results[event_id] = e.error_dict() | |
356 | return | |
357 | ||
342 | 358 | for pdu in pdus_by_room[room_id]: |
343 | event_id = pdu.event_id | |
344 | pdu_results[event_id] = e.error_dict() | |
345 | return | |
346 | ||
347 | for pdu in pdus_by_room[room_id]: | |
348 | event_id = pdu.event_id | |
349 | with pdu_process_time.time(): | |
350 | with nested_logging_context(event_id): | |
351 | try: | |
352 | await self._handle_received_pdu(origin, pdu) | |
353 | pdu_results[event_id] = {} | |
354 | except FederationError as e: | |
355 | logger.warning("Error handling PDU %s: %s", event_id, e) | |
356 | pdu_results[event_id] = {"error": str(e)} | |
357 | except Exception as e: | |
358 | f = failure.Failure() | |
359 | pdu_results[event_id] = {"error": str(e)} | |
360 | logger.error( | |
361 | "Failed to handle PDU %s", | |
362 | event_id, | |
363 | exc_info=(f.type, f.value, f.getTracebackObject()), | |
364 | ) | |
359 | pdu_results[pdu.event_id] = await process_pdu(pdu) | |
360 | ||
361 | async def process_pdu(pdu: EventBase) -> JsonDict: | |
362 | event_id = pdu.event_id | |
363 | with pdu_process_time.time(): | |
364 | with nested_logging_context(event_id): | |
365 | try: | |
366 | await self._handle_received_pdu(origin, pdu) | |
367 | return {} | |
368 | except FederationError as e: | |
369 | logger.warning("Error handling PDU %s: %s", event_id, e) | |
370 | return {"error": str(e)} | |
371 | except Exception as e: | |
372 | f = failure.Failure() | |
373 | logger.error( | |
374 | "Failed to handle PDU %s", | |
375 | event_id, | |
376 | exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore | |
377 | ) | |
378 | return {"error": str(e)} | |
365 | 379 | |
366 | 380 | await concurrently_execute( |
367 | 381 | process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT |
368 | 382 | ) |
369 | 383 | |
370 | 384 | if newest_pdu_ts and origin in self._federation_metrics_domains: |
371 | newest_pdu_age = self._clock.time_msec() - newest_pdu_ts | |
372 | last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000) | |
385 | last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000) | |
373 | 386 | |
374 | 387 | return pdu_results |
375 | 388 | |
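The hunks above replace the inline per-event try/except with a dedicated `process_pdu` coroutine whose return value is recorded under the event's ID, separating per-event error reporting from the per-room loop. A minimal sketch of the resulting shape, assuming a placeholder `handle_received_pdu` in place of Synapse's `_handle_received_pdu` (the real code additionally wraps each event in `pdu_process_time` and a nested logging context):

    from typing import Any, Dict, List

    JsonDict = Dict[str, Any]

    async def handle_received_pdu(pdu: JsonDict) -> None:
        # Placeholder for _handle_received_pdu; raises to simulate a rejection.
        if pdu.get("reject"):
            raise ValueError("rejected")

    async def process_pdu(pdu: JsonDict) -> JsonDict:
        # Empty dict = success; an error dict is what the sending server
        # sees for this event in the transaction response.
        try:
            await handle_received_pdu(pdu)
            return {}
        except Exception as e:
            return {"error": str(e)}

    async def process_pdus_for_room(
        room_id: str, pdus_by_room: Dict[str, List[JsonDict]]
    ) -> JsonDict:
        results = {}
        for pdu in pdus_by_room[room_id]:
            results[pdu["event_id"]] = await process_pdu(pdu)
        return results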
447 | 460 | |
448 | 461 | async def _on_state_ids_request_compute(self, room_id, event_id): |
449 | 462 | state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) |
450 | auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) | |
463 | auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) | |
451 | 464 | return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} |
452 | 465 | |
453 | 466 | async def _on_context_state_request_compute( |
454 | 467 | self, room_id: str, event_id: str |
455 | 468 | ) -> Dict[str, list]: |
456 | 469 | if event_id: |
457 | pdus = await self.handler.get_state_for_pdu(room_id, event_id) | |
470 | pdus = await self.handler.get_state_for_pdu( | |
471 | room_id, event_id | |
472 | ) # type: Iterable[EventBase] | |
458 | 473 | else: |
459 | 474 | pdus = (await self.state.get_current_state(room_id)).values() |
460 | 475 | |
461 | auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) | |
476 | auth_chain = await self.store.get_auth_chain( | |
477 | room_id, [pdu.event_id for pdu in pdus] | |
478 | ) | |
462 | 479 | |
463 | 480 | return { |
464 | 481 | "pdus": [pdu.get_pdu_json() for pdu in pdus], |
862 | 879 | self.edu_handlers = ( |
863 | 880 | {} |
864 | 881 | ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] |
865 | self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] | |
882 | self.query_handlers = ( | |
883 | {} | |
884 | ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]] | |
866 | 885 | |
867 | 886 | # Map from type to instance names that we should route EDU handling to. |
868 | 887 | # We randomly choose one instance from the list to route to for each new |
896 | 915 | self.edu_handlers[edu_type] = handler |
897 | 916 | |
898 | 917 | def register_query_handler( |
899 | self, query_type: str, handler: Callable[[dict], defer.Deferred] | |
918 | self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] | |
900 | 919 | ): |
901 | 920 | """Sets the handler callable that will be used to handle an incoming |
902 | 921 | federation query of the given type. |
969 | 988 | # Oh well, let's just log and move on. |
970 | 989 | logger.warning("No handler registered for EDU type %s", edu_type) |
971 | 990 | |
972 | async def on_query(self, query_type: str, args: dict): | |
991 | async def on_query(self, query_type: str, args: dict) -> JsonDict: | |
973 | 992 | handler = self.query_handlers.get(query_type) |
974 | 993 | if handler: |
975 | 994 | return await handler(args) |
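The handler-typing change above tightens `register_query_handler` from Twisted's `defer.Deferred` to a plain `Awaitable[JsonDict]`, matching the `await handler(args)` call in `on_query`. A conforming registration might look like this sketch (the handler body is purely illustrative):

    from typing import Any, Awaitable, Callable, Dict

    JsonDict = Dict[str, Any]

    query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]

    def register_query_handler(
        query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
    ) -> None:
        query_handlers[query_type] = handler

    async def on_profile_query(args: dict) -> JsonDict:
        # Any coroutine returning a JSON-serialisable dict now satisfies
        # the annotation; no Deferred plumbing is needed.
        return {"displayname": args.get("user_id", "")}

    register_query_handler("profile", on_profile_query)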
16 | 16 | import logging |
17 | 17 | from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast |
18 | 18 | |
19 | import attr | |
19 | 20 | from prometheus_client import Counter |
20 | 21 | |
21 | 22 | from synapse.api.errors import ( |
92 | 93 | self._destination = destination |
93 | 94 | self.transmission_loop_running = False |
94 | 95 | |
96 | # Flag to signal to any running transmission loop that there is new data | |
97 | # queued up to be sent. | |
98 | self._new_data_to_send = False | |
99 | ||
95 | 100 | # True whilst we are sending events that the remote homeserver missed |
96 | 101 | # because it was unreachable. We start in this state so we can perform |
97 | 102 | # catch-up at startup. |
107 | 112 | # destination (we are the only updater so this is safe) |
108 | 113 | self._last_successful_stream_ordering = None # type: Optional[int] |
109 | 114 | |
110 | # a list of pending PDUs | |
115 | # a queue of pending PDUs | |
111 | 116 | self._pending_pdus = [] # type: List[EventBase] |
112 | 117 | |
113 | 118 | # XXX this is never actually used: see |
207 | 212 | transaction in the background. |
208 | 213 | """ |
209 | 214 | |
215 | # Mark that we (may) have new things to send, so that any running | |
216 | # transmission loop will recheck whether there is stuff to send. | |
217 | self._new_data_to_send = True | |
218 | ||
210 | 219 | if self.transmission_loop_running: |
211 | 220 | # XXX: this can get stuck on by a never-ending |
212 | 221 | # request at which point pending_pdus just keeps growing. |
249 | 258 | |
250 | 259 | pending_pdus = [] |
251 | 260 | while True: |
252 | # We have to keep 2 free slots for presence and rr_edus | |
253 | limit = MAX_EDUS_PER_TRANSACTION - 2 | |
254 | ||
255 | device_update_edus, dev_list_id = await self._get_device_update_edus( | |
256 | limit | |
257 | ) | |
258 | ||
259 | limit -= len(device_update_edus) | |
260 | ||
261 | ( | |
262 | to_device_edus, | |
263 | device_stream_id, | |
264 | ) = await self._get_to_device_message_edus(limit) | |
265 | ||
266 | pending_edus = device_update_edus + to_device_edus | |
267 | ||
268 | # BEGIN CRITICAL SECTION | |
269 | # | |
270 | # In order to avoid a race condition, we need to make sure that | |
271 | # the following code (from popping the queues up to the point | |
272 | # where we decide if we actually have any pending messages) is | |
273 | # atomic - otherwise new PDUs or EDUs might arrive in the | |
274 | # meantime, but not get sent because we hold the | |
275 | # transmission_loop_running flag. | |
276 | ||
277 | pending_pdus = self._pending_pdus | |
278 | ||
279 | # We can only include at most 50 PDUs per transaction | |
280 | pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:] | |
281 | ||
282 | pending_edus.extend(self._get_rr_edus(force_flush=False)) | |
283 | pending_presence = self._pending_presence | |
284 | self._pending_presence = {} | |
285 | if pending_presence: | |
286 | pending_edus.append( | |
287 | Edu( | |
288 | origin=self._server_name, | |
289 | destination=self._destination, | |
290 | edu_type="m.presence", | |
291 | content={ | |
292 | "push": [ | |
293 | format_user_presence_state( | |
294 | presence, self._clock.time_msec() | |
295 | ) | |
296 | for presence in pending_presence.values() | |
297 | ] | |
298 | }, | |
261 | self._new_data_to_send = False | |
262 | ||
263 | async with _TransactionQueueManager(self) as ( | |
264 | pending_pdus, | |
265 | pending_edus, | |
266 | ): | |
267 | if not pending_pdus and not pending_edus: | |
268 | logger.debug("TX [%s] Nothing to send", self._destination) | |
269 | ||
270 | # If we've gotten told about new things to send during | |
271 | # checking for things to send, we try looking again. | |
272 | # Otherwise new PDUs or EDUs might arrive in the meantime, | |
273 | # but not get sent because we hold the | |
274 | # `transmission_loop_running` flag. | |
275 | if self._new_data_to_send: | |
276 | continue | |
277 | else: | |
278 | return | |
279 | ||
280 | if pending_pdus: | |
281 | logger.debug( | |
282 | "TX [%s] len(pending_pdus_by_dest[dest]) = %d", | |
283 | self._destination, | |
284 | len(pending_pdus), | |
299 | 285 | ) |
286 | ||
287 | await self._transaction_manager.send_new_transaction( | |
288 | self._destination, pending_pdus, pending_edus | |
300 | 289 | ) |
301 | 290 | |
302 | pending_edus.extend( | |
303 | self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) | |
304 | ) | |
305 | while ( | |
306 | len(pending_edus) < MAX_EDUS_PER_TRANSACTION | |
307 | and self._pending_edus_keyed | |
308 | ): | |
309 | _, val = self._pending_edus_keyed.popitem() | |
310 | pending_edus.append(val) | |
311 | ||
312 | if pending_pdus: | |
313 | logger.debug( | |
314 | "TX [%s] len(pending_pdus_by_dest[dest]) = %d", | |
315 | self._destination, | |
316 | len(pending_pdus), | |
317 | ) | |
318 | ||
319 | if not pending_pdus and not pending_edus: | |
320 | logger.debug("TX [%s] Nothing to send", self._destination) | |
321 | self._last_device_stream_id = device_stream_id | |
322 | return | |
323 | ||
324 | # if we've decided to send a transaction anyway, and we have room, we | |
325 | # may as well send any pending RRs | |
326 | if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: | |
327 | pending_edus.extend(self._get_rr_edus(force_flush=True)) | |
328 | ||
329 | # END CRITICAL SECTION | |
330 | ||
331 | success = await self._transaction_manager.send_new_transaction( | |
332 | self._destination, pending_pdus, pending_edus | |
333 | ) | |
334 | if success: | |
335 | 291 | sent_transactions_counter.inc() |
336 | 292 | sent_edus_counter.inc(len(pending_edus)) |
337 | 293 | for edu in pending_edus: |
338 | 294 | sent_edus_by_type.labels(edu.edu_type).inc() |
339 | # Remove the acknowledged device messages from the database | |
340 | # Only bother if we actually sent some device messages | |
341 | if to_device_edus: | |
342 | await self._store.delete_device_msgs_for_remote( | |
343 | self._destination, device_stream_id | |
344 | ) | |
345 | ||
346 | # also mark the device updates as sent | |
347 | if device_update_edus: | |
348 | logger.info( | |
349 | "Marking as sent %r %r", self._destination, dev_list_id | |
350 | ) | |
351 | await self._store.mark_as_sent_devices_by_remote( | |
352 | self._destination, dev_list_id | |
353 | ) | |
354 | ||
355 | self._last_device_stream_id = device_stream_id | |
356 | self._last_device_list_stream_id = dev_list_id | |
357 | ||
358 | if pending_pdus: | |
359 | # we sent some PDUs and it was successful, so update our | |
360 | # last_successful_stream_ordering in the destinations table. | |
361 | final_pdu = pending_pdus[-1] | |
362 | last_successful_stream_ordering = ( | |
363 | final_pdu.internal_metadata.stream_ordering | |
364 | ) | |
365 | assert last_successful_stream_ordering | |
366 | await self._store.set_destination_last_successful_stream_ordering( | |
367 | self._destination, last_successful_stream_ordering | |
368 | ) | |
369 | else: | |
370 | break | |
295 | ||
371 | 296 | except NotRetryingDestination as e: |
372 | 297 | logger.debug( |
373 | 298 | "TX [%s] not ready for retry yet (next retry at %s) - " |
400 | 325 | self._pending_presence = {} |
401 | 326 | self._pending_rrs = {} |
402 | 327 | |
403 | self._start_catching_up() | |
328 | self._start_catching_up() | |
404 | 329 | except FederationDeniedError as e: |
405 | 330 | logger.info(e) |
406 | 331 | except HttpResponseException as e: |
411 | 336 | e, |
412 | 337 | ) |
413 | 338 | |
414 | self._start_catching_up() | |
415 | 339 | except RequestSendFailed as e: |
416 | 340 | logger.warning( |
417 | 341 | "TX [%s] Failed to send transaction: %s", self._destination, e |
421 | 345 | logger.info( |
422 | 346 | "Failed to send event %s to %s", p.event_id, self._destination |
423 | 347 | ) |
424 | ||
425 | self._start_catching_up() | |
426 | 348 | except Exception: |
427 | 349 | logger.exception("TX [%s] Failed to send transaction", self._destination) |
428 | 350 | for p in pending_pdus: |
429 | 351 | logger.info( |
430 | 352 | "Failed to send event %s to %s", p.event_id, self._destination |
431 | 353 | ) |
432 | ||
433 | self._start_catching_up() | |
434 | 354 | finally: |
435 | 355 | # We want to be *very* sure we clear this after we stop processing |
436 | 356 | self.transmission_loop_running = False |
498 | 418 | rooms = [p.room_id for p in catchup_pdus] |
499 | 419 | logger.info("Catching up rooms to %s: %r", self._destination, rooms) |
500 | 420 | |
501 | success = await self._transaction_manager.send_new_transaction( | |
421 | await self._transaction_manager.send_new_transaction( | |
502 | 422 | self._destination, catchup_pdus, [] |
503 | 423 | ) |
504 | ||
505 | if not success: | |
506 | return | |
507 | 424 | |
508 | 425 | sent_transactions_counter.inc() |
509 | 426 | final_pdu = catchup_pdus[-1] |
583 | 500 | """ |
584 | 501 | self._catching_up = True |
585 | 502 | self._pending_pdus = [] |
503 | ||
504 | ||
505 | @attr.s(slots=True) | |
506 | class _TransactionQueueManager: | |
507 | """A helper async context manager for pulling stuff off the queues and | |
508 | tracking what was last successfully sent, etc. | |
509 | """ | |
510 | ||
511 | queue = attr.ib(type=PerDestinationQueue) | |
512 | ||
513 | _device_stream_id = attr.ib(type=Optional[int], default=None) | |
514 | _device_list_id = attr.ib(type=Optional[int], default=None) | |
515 | _last_stream_ordering = attr.ib(type=Optional[int], default=None) | |
516 | _pdus = attr.ib(type=List[EventBase], factory=list) | |
517 | ||
518 | async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: | |
519 | # First we calculate the EDUs we want to send, if any. | |
520 | ||
521 | # We start by fetching device related EDUs, i.e device updates and to | |
522 | # device messages. We have to keep 2 free slots for presence and rr_edus. | |
523 | limit = MAX_EDUS_PER_TRANSACTION - 2 | |
524 | ||
525 | device_update_edus, dev_list_id = await self.queue._get_device_update_edus( | |
526 | limit | |
527 | ) | |
528 | ||
529 | if device_update_edus: | |
530 | self._device_list_id = dev_list_id | |
531 | else: | |
532 | self.queue._last_device_list_stream_id = dev_list_id | |
533 | ||
534 | limit -= len(device_update_edus) | |
535 | ||
536 | ( | |
537 | to_device_edus, | |
538 | device_stream_id, | |
539 | ) = await self.queue._get_to_device_message_edus(limit) | |
540 | ||
541 | if to_device_edus: | |
542 | self._device_stream_id = device_stream_id | |
543 | else: | |
544 | self.queue._last_device_stream_id = device_stream_id | |
545 | ||
546 | pending_edus = device_update_edus + to_device_edus | |
547 | ||
548 | # Now add the read receipt EDU. | |
549 | pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) | |
550 | ||
551 | # And presence EDU. | |
552 | if self.queue._pending_presence: | |
553 | pending_edus.append( | |
554 | Edu( | |
555 | origin=self.queue._server_name, | |
556 | destination=self.queue._destination, | |
557 | edu_type="m.presence", | |
558 | content={ | |
559 | "push": [ | |
560 | format_user_presence_state( | |
561 | presence, self.queue._clock.time_msec() | |
562 | ) | |
563 | for presence in self.queue._pending_presence.values() | |
564 | ] | |
565 | }, | |
566 | ) | |
567 | ) | |
568 | self.queue._pending_presence = {} | |
569 | ||
570 | # Finally add any other types of EDUs if there is room. | |
571 | pending_edus.extend( | |
572 | self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) | |
573 | ) | |
574 | while ( | |
575 | len(pending_edus) < MAX_EDUS_PER_TRANSACTION | |
576 | and self.queue._pending_edus_keyed | |
577 | ): | |
578 | _, val = self.queue._pending_edus_keyed.popitem() | |
579 | pending_edus.append(val) | |
580 | ||
581 | # Now we look for any PDUs to send, by getting up to 50 PDUs from the | |
582 | # queue | |
583 | self._pdus = self.queue._pending_pdus[:50] | |
584 | ||
585 | if not self._pdus and not pending_edus: | |
586 | return [], [] | |
587 | ||
588 | # if we've decided to send a transaction anyway, and we have room, we | |
589 | # may as well send any pending RRs | |
590 | if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: | |
591 | pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) | |
592 | ||
593 | if self._pdus: | |
594 | self._last_stream_ordering = self._pdus[ | |
595 | -1 | |
596 | ].internal_metadata.stream_ordering | |
597 | assert self._last_stream_ordering | |
598 | ||
599 | return self._pdus, pending_edus | |
600 | ||
601 | async def __aexit__(self, exc_type, exc, tb): | |
602 | if exc_type is not None: | |
603 | # Failed to send transaction, so we bail out. | |
604 | return | |
605 | ||
606 | # Successfully sent the transaction, so we remove the pending PDUs from the queue |
607 | if self._pdus: | |
608 | self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :] | |
609 | ||
610 | # We succeeded in sending the transaction, so we record how far we |
611 | # have sent in the various streams |
612 | ||
613 | if self._device_stream_id: | |
614 | await self.queue._store.delete_device_msgs_for_remote( | |
615 | self.queue._destination, self._device_stream_id | |
616 | ) | |
617 | self.queue._last_device_stream_id = self._device_stream_id | |
618 | ||
619 | # also mark the device updates as sent | |
620 | if self._device_list_id: | |
621 | logger.info( | |
622 | "Marking as sent %r %r", self.queue._destination, self._device_list_id | |
623 | ) | |
624 | await self.queue._store.mark_as_sent_devices_by_remote( | |
625 | self.queue._destination, self._device_list_id | |
626 | ) | |
627 | self.queue._last_device_list_stream_id = self._device_list_id | |
628 | ||
629 | if self._last_stream_ordering: | |
630 | # we sent some PDUs and it was successful, so update our | |
631 | # last_successful_stream_ordering in the destinations table. | |
632 | await self.queue._store.set_destination_last_successful_stream_ordering( | |
633 | self.queue._destination, self._last_stream_ordering | |
634 | ) |
35 | 35 | |
36 | 36 | logger = logging.getLogger(__name__) |
37 | 37 | |
38 | last_pdu_age_metric = Gauge( | |
39 | "synapse_federation_last_sent_pdu_age", | |
40 | "The age (in seconds) of the last PDU successfully sent to the given domain", | |
38 | last_pdu_ts_metric = Gauge( | |
39 | "synapse_federation_last_sent_pdu_time", | |
40 | "The timestamp of the last PDU which was successfully sent to the given domain", | |
41 | 41 | labelnames=("server_name",), |
42 | 42 | ) |
43 | 43 | |
68 | 68 | destination: str, |
69 | 69 | pdus: List[EventBase], |
70 | 70 | edus: List[Edu], |
71 | ) -> bool: | |
71 | ) -> None: | |
72 | 72 | """ |
73 | 73 | Args: |
74 | 74 | destination: The destination to send to (e.g. 'example.org') |
75 | 75 | pdus: In-order list of PDUs to send |
76 | 76 | edus: List of EDUs to send |
77 | ||
78 | Returns: | |
79 | True iff the transaction was successful | |
80 | 77 | """ |
81 | 78 | |
82 | 79 | # Make a transaction-sending opentracing span. This span follows on from |
95 | 92 | edu.strip_context() |
96 | 93 | |
97 | 94 | with start_active_span_follows_from("send_transaction", span_contexts): |
98 | success = True | |
99 | ||
100 | 95 | logger.debug("TX [%s] _attempt_new_transaction", destination) |
101 | 96 | |
102 | 97 | txn_id = str(self._next_txn_id) |
151 | 146 | response = await self._transport_layer.send_transaction( |
152 | 147 | transaction, json_data_cb |
153 | 148 | ) |
154 | code = 200 | |
155 | 149 | except HttpResponseException as e: |
156 | 150 | code = e.code |
157 | 151 | response = e.response |
158 | 152 | |
159 | if e.code in (401, 404, 429) or 500 <= e.code: | |
160 | logger.info( | |
161 | "TX [%s] {%s} got %d response", destination, txn_id, code | |
162 | ) | |
163 | raise e | |
153 | set_tag(tags.ERROR, True) | |
164 | 154 | |
165 | logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) | |
155 | logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) | |
156 | raise | |
166 | 157 | |
167 | if code == 200: | |
168 | for e_id, r in response.get("pdus", {}).items(): | |
169 | if "error" in r: | |
170 | logger.warning( | |
171 | "TX [%s] {%s} Remote returned error for %s: %s", | |
172 | destination, | |
173 | txn_id, | |
174 | e_id, | |
175 | r, | |
176 | ) | |
177 | else: | |
178 | for p in pdus: | |
158 | logger.info("TX [%s] {%s} got 200 response", destination, txn_id) | |
159 | ||
160 | for e_id, r in response.get("pdus", {}).items(): | |
161 | if "error" in r: | |
179 | 162 | logger.warning( |
180 | "TX [%s] {%s} Failed to send event %s", | |
163 | "TX [%s] {%s} Remote returned error for %s: %s", | |
181 | 164 | destination, |
182 | 165 | txn_id, |
183 | p.event_id, | |
166 | e_id, | |
167 | r, | |
184 | 168 | ) |
185 | success = False | |
186 | 169 | |
187 | if success and pdus and destination in self._federation_metrics_domains: | |
170 | if pdus and destination in self._federation_metrics_domains: | |
188 | 171 | last_pdu = pdus[-1] |
189 | last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts | |
190 | last_pdu_age_metric.labels(server_name=destination).set( | |
191 | last_pdu_age / 1000 | |
172 | last_pdu_ts_metric.labels(server_name=destination).set( | |
173 | last_pdu.origin_server_ts / 1000 | |
192 | 174 | ) |
193 | ||
194 | set_tag(tags.ERROR, not success) | |
195 | return success |
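The removed `success` plumbing above is the other half of that contract: `send_new_transaction` now returns `None` and signals failure by letting the `HttpResponseException` (or other error) propagate, so callers such as the catch-up path no longer check a boolean. Schematically, with `send` as a stand-in for the transaction manager:

    async def send(destination: str, pdus: list, edus: list) -> None:
        # Stand-in for send_new_transaction: raises on any non-2xx
        # response instead of returning False.
        ...

    async def attempt_transaction(destination: str, pdus: list, edus: list) -> None:
        await send(destination, pdus, edus)  # raises on failure
        # Everything from here on is success-only bookkeeping
        # (counters, stream positions, queue trimming).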
72 | 72 | "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port |
73 | 73 | ) |
74 | 74 | try: |
75 | self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host) | |
75 | self.reactor.listenTCP( | |
76 | self.hs.config.acme_port, srv, backlog=50, interface=host | |
77 | ) | |
76 | 78 | except twisted.internet.error.CannotListenError as e: |
77 | 79 | check_bind_error(e, host, bind_addresses) |
78 | 80 |
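The only functional content in this hunk is spelling out Twisted's `backlog` argument, whose default is already 50, so the call is behaviour-neutral. For reference, a minimal standalone `listenTCP` with the same shape:

    from twisted.internet import protocol, reactor

    class Echo(protocol.Protocol):
        def dataReceived(self, data: bytes) -> None:
            self.transport.write(data)

    factory = protocol.ServerFactory()
    factory.protocol = Echo

    # backlog bounds the kernel queue of connections not yet accept()ed
    reactor.listenTCP(8009, factory, backlog=50, interface="127.0.0.1")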
64 | 64 | from synapse.types import JsonDict, Requester, UserID |
65 | 65 | from synapse.util import stringutils as stringutils |
66 | 66 | from synapse.util.async_helpers import maybe_awaitable |
67 | from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry | |
67 | 68 | from synapse.util.msisdn import phone_number_to_msisdn |
68 | 69 | from synapse.util.threepids import canonicalise_email |
69 | 70 | |
169 | 170 | extra_attributes = attr.ib(type=JsonDict) |
170 | 171 | |
171 | 172 | |
173 | @attr.s(slots=True, frozen=True) | |
174 | class LoginTokenAttributes: | |
175 | """Data we store in a short-term login token""" | |
176 | ||
177 | user_id = attr.ib(type=str) | |
178 | ||
179 | # the SSO Identity Provider that the user authenticated with, to get this token | |
180 | auth_provider_id = attr.ib(type=str) | |
181 | ||
182 | ||
172 | 183 | class AuthHandler(BaseHandler): |
173 | 184 | SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 |
174 | 185 | |
325 | 336 | user is too high to proceed |
326 | 337 | |
327 | 338 | """ |
328 | ||
339 | if not requester.access_token_id: | |
340 | raise ValueError("Cannot validate a user without an access token") | |
329 | 341 | if self._ui_auth_session_timeout: |
330 | 342 | last_validated = await self.store.get_access_token_last_validated( |
331 | 343 | requester.access_token_id |
1163 | 1175 | return None |
1164 | 1176 | return user_id |
1165 | 1177 | |
1166 | async def validate_short_term_login_token_and_get_user_id(self, login_token: str): | |
1167 | auth_api = self.hs.get_auth() | |
1168 | user_id = None | |
1178 | async def validate_short_term_login_token( | |
1179 | self, login_token: str | |
1180 | ) -> LoginTokenAttributes: | |
1169 | 1181 | try: |
1170 | macaroon = pymacaroons.Macaroon.deserialize(login_token) | |
1171 | user_id = auth_api.get_user_id_from_macaroon(macaroon) | |
1172 | auth_api.validate_macaroon(macaroon, "login", user_id) | |
1182 | res = self.macaroon_gen.verify_short_term_login_token(login_token) | |
1173 | 1183 | except Exception: |
1174 | 1184 | raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) |
1175 | 1185 | |
1176 | await self.auth.check_auth_blocking(user_id) | |
1177 | return user_id | |
1186 | await self.auth.check_auth_blocking(res.user_id) | |
1187 | return res | |
1178 | 1188 | |
1179 | 1189 | async def delete_access_token(self, access_token: str): |
1180 | 1190 | """Invalidate a single access token |
1203 | 1213 | async def delete_access_tokens_for_user( |
1204 | 1214 | self, |
1205 | 1215 | user_id: str, |
1206 | except_token_id: Optional[str] = None, | |
1216 | except_token_id: Optional[int] = None, | |
1207 | 1217 | device_id: Optional[str] = None, |
1208 | 1218 | ): |
1209 | 1219 | """Invalidate access tokens belonging to a user |
1396 | 1406 | async def complete_sso_login( |
1397 | 1407 | self, |
1398 | 1408 | registered_user_id: str, |
1409 | auth_provider_id: str, | |
1399 | 1410 | request: Request, |
1400 | 1411 | client_redirect_url: str, |
1401 | 1412 | extra_attributes: Optional[JsonDict] = None, |
1405 | 1416 | |
1406 | 1417 | Args: |
1407 | 1418 | registered_user_id: The registered user ID to complete SSO login for. |
1419 | auth_provider_id: The id of the SSO Identity provider that was used for | |
1420 | login. This will be stored in the login token for future tracking in | |
1421 | prometheus metrics. | |
1408 | 1422 | request: The request to complete. |
1409 | 1423 | client_redirect_url: The URL to which to redirect the user at the end of the |
1410 | 1424 | process. |
1426 | 1440 | |
1427 | 1441 | self._complete_sso_login( |
1428 | 1442 | registered_user_id, |
1443 | auth_provider_id, | |
1429 | 1444 | request, |
1430 | 1445 | client_redirect_url, |
1431 | 1446 | extra_attributes, |
1436 | 1451 | def _complete_sso_login( |
1437 | 1452 | self, |
1438 | 1453 | registered_user_id: str, |
1454 | auth_provider_id: str, | |
1439 | 1455 | request: Request, |
1440 | 1456 | client_redirect_url: str, |
1441 | 1457 | extra_attributes: Optional[JsonDict] = None, |
1462 | 1478 | |
1463 | 1479 | # Create a login token |
1464 | 1480 | login_token = self.macaroon_gen.generate_short_term_login_token( |
1465 | registered_user_id | |
1481 | registered_user_id, auth_provider_id=auth_provider_id | |
1466 | 1482 | ) |
1467 | 1483 | |
1468 | 1484 | # Append the login token to the original redirect URL (i.e. with its query |
1568 | 1584 | return macaroon.serialize() |
1569 | 1585 | |
1570 | 1586 | def generate_short_term_login_token( |
1571 | self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) | |
1587 | self, | |
1588 | user_id: str, | |
1589 | auth_provider_id: str, | |
1590 | duration_in_ms: int = (2 * 60 * 1000), | |
1572 | 1591 | ) -> str: |
1573 | 1592 | macaroon = self._generate_base_macaroon(user_id) |
1574 | 1593 | macaroon.add_first_party_caveat("type = login") |
1575 | 1594 | now = self.hs.get_clock().time_msec() |
1576 | 1595 | expiry = now + duration_in_ms |
1577 | 1596 | macaroon.add_first_party_caveat("time < %d" % (expiry,)) |
1597 | macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) | |
1578 | 1598 | return macaroon.serialize() |
1599 | ||
1600 | def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: | |
1601 | """Verify a short-term-login macaroon | |
1602 | ||
1603 | Checks that the given token is a valid, unexpired short-term-login token | |
1604 | minted by this server. | |
1605 | ||
1606 | Args: | |
1607 | token: the login token to verify | |
1608 | ||
1609 | Returns: | |
1610 | the user_id and auth_provider_id carried by the token |
1611 | ||
1612 | Raises: | |
1613 | MacaroonVerificationFailedException if the verification failed | |
1614 | """ | |
1615 | macaroon = pymacaroons.Macaroon.deserialize(token) | |
1616 | user_id = get_value_from_macaroon(macaroon, "user_id") | |
1617 | auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") | |
1618 | ||
1619 | v = pymacaroons.Verifier() | |
1620 | v.satisfy_exact("gen = 1") | |
1621 | v.satisfy_exact("type = login") | |
1622 | v.satisfy_general(lambda c: c.startswith("user_id = ")) | |
1623 | v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) | |
1624 | satisfy_expiry(v, self.hs.get_clock().time_msec) | |
1625 | v.verify(macaroon, self.hs.config.key.macaroon_secret_key) | |
1626 | ||
1627 | return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) | |
1579 | 1628 | |
1580 | 1629 | def generate_delete_pusher_token(self, user_id: str) -> str: |
1581 | 1630 | macaroon = self._generate_base_macaroon(user_id) |
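The new `auth_provider_id = ...` caveat rides in the same short-term login macaroon as the existing `gen`, `type`, `user_id` and expiry caveats, and `verify_short_term_login_token` simply gains one more `satisfy_general` rule. A self-contained sketch of minting and verifying such a token with pymacaroons (key, location and identifier are placeholders):

    import time

    import pymacaroons

    SECRET = "macaroon-secret-key"  # placeholder

    def mint(user_id: str, auth_provider_id: str, duration_ms: int = 2 * 60 * 1000) -> str:
        m = pymacaroons.Macaroon(location="example.com", identifier="key", key=SECRET)
        m.add_first_party_caveat("gen = 1")
        m.add_first_party_caveat("type = login")
        m.add_first_party_caveat("user_id = %s" % (user_id,))
        m.add_first_party_caveat("time < %d" % (int(time.time() * 1000) + duration_ms,))
        m.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
        return m.serialize()

    def verify(token: str) -> None:
        m = pymacaroons.Macaroon.deserialize(token)
        v = pymacaroons.Verifier()
        v.satisfy_exact("gen = 1")
        v.satisfy_exact("type = login")
        v.satisfy_general(lambda c: c.startswith("user_id = "))
        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
        v.satisfy_general(
            lambda c: c.startswith("time < ")
            and int(time.time() * 1000) < int(c[len("time < "):])
        )
        v.verify(m, SECRET)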
82 | 82 | # the SsoIdentityProvider protocol type. |
83 | 83 | self.idp_icon = None |
84 | 84 | self.idp_brand = None |
85 | self.unstable_idp_brand = None | |
85 | 86 | |
86 | 87 | self._sso_handler = hs.get_sso_handler() |
87 | 88 |
200 | 200 | or pdu.internal_metadata.is_outlier() |
201 | 201 | ) |
202 | 202 | if already_seen: |
203 | logger.debug("[%s %s]: Already seen pdu", room_id, event_id) | |
203 | logger.debug("Already seen pdu") | |
204 | 204 | return |
205 | 205 | |
206 | 206 | # do some initial sanity-checking of the event. In particular, make |
209 | 209 | try: |
210 | 210 | self._sanity_check_event(pdu) |
211 | 211 | except SynapseError as err: |
212 | logger.warning( | |
213 | "[%s %s] Received event failed sanity checks", room_id, event_id | |
214 | ) | |
212 | logger.warning("Received event failed sanity checks") | |
215 | 213 | raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) |
216 | 214 | |
217 | 215 | # If we are currently in the process of joining this room, then we |
218 | 216 | # queue up events for later processing. |
219 | 217 | if room_id in self.room_queues: |
220 | 218 | logger.info( |
221 | "[%s %s] Queuing PDU from %s for now: join in progress", | |
222 | room_id, | |
223 | event_id, | |
219 | "Queuing PDU from %s for now: join in progress", | |
224 | 220 | origin, |
225 | 221 | ) |
226 | 222 | self.room_queues[room_id].append((pdu, origin)) |
235 | 231 | is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) |
236 | 232 | if not is_in_room: |
237 | 233 | logger.info( |
238 | "[%s %s] Ignoring PDU from %s as we're not in the room", | |
239 | room_id, | |
240 | event_id, | |
234 | "Ignoring PDU from %s as we're not in the room", | |
241 | 235 | origin, |
242 | 236 | ) |
243 | 237 | return None |
249 | 243 | # We only backfill backwards to the min depth. |
250 | 244 | min_depth = await self.get_min_depth_for_context(pdu.room_id) |
251 | 245 | |
252 | logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) | |
246 | logger.debug("min_depth: %d", min_depth) | |
253 | 247 | |
254 | 248 | prevs = set(pdu.prev_event_ids()) |
255 | 249 | seen = await self.store.have_events_in_timeline(prevs) |
266 | 260 | # If we're missing stuff, ensure we only fetch stuff one |
267 | 261 | # at a time. |
268 | 262 | logger.info( |
269 | "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", | |
270 | room_id, | |
271 | event_id, | |
263 | "Acquiring room lock to fetch %d missing prev_events: %s", | |
272 | 264 | len(missing_prevs), |
273 | 265 | shortstr(missing_prevs), |
274 | 266 | ) |
275 | 267 | with (await self._room_pdu_linearizer.queue(pdu.room_id)): |
276 | 268 | logger.info( |
277 | "[%s %s] Acquired room lock to fetch %d missing prev_events", | |
278 | room_id, | |
279 | event_id, | |
269 | "Acquired room lock to fetch %d missing prev_events", | |
280 | 270 | len(missing_prevs), |
281 | 271 | ) |
282 | 272 | |
296 | 286 | |
297 | 287 | if not prevs - seen: |
298 | 288 | logger.info( |
299 | "[%s %s] Found all missing prev_events", | |
300 | room_id, | |
301 | event_id, | |
289 | "Found all missing prev_events", | |
302 | 290 | ) |
303 | 291 | |
304 | 292 | if prevs - seen: |
328 | 316 | |
329 | 317 | if sent_to_us_directly: |
330 | 318 | logger.warning( |
331 | "[%s %s] Rejecting: failed to fetch %d prev events: %s", | |
332 | room_id, | |
333 | event_id, | |
319 | "Rejecting: failed to fetch %d prev events: %s", | |
334 | 320 | len(prevs - seen), |
335 | 321 | shortstr(prevs - seen), |
336 | 322 | ) |
366 | 352 | # Ask the remote server for the states we don't |
367 | 353 | # know about |
368 | 354 | for p in prevs - seen: |
369 | logger.info( | |
370 | "Requesting state at missing prev_event %s", | |
371 | event_id, | |
372 | ) | |
355 | logger.info("Requesting state after missing prev_event %s", p) | |
373 | 356 | |
374 | 357 | with nested_logging_context(p): |
375 | 358 | # note that if any of the missing prevs share missing state or |
376 | 359 | # auth events, the requests to fetch those events are deduped |
377 | 360 | # by the get_pdu_cache in federation_client. |
378 | (remote_state, _,) = await self._get_state_for_room( | |
379 | origin, room_id, p, include_event_in_state=True | |
361 | remote_state = ( | |
362 | await self._get_state_after_missing_prev_event( | |
363 | origin, room_id, p | |
364 | ) | |
380 | 365 | ) |
381 | 366 | |
382 | 367 | remote_state_map = { |
413 | 398 | state = [event_map[e] for e in state_map.values()] |
414 | 399 | except Exception: |
415 | 400 | logger.warning( |
416 | "[%s %s] Error attempting to resolve state at missing " | |
417 | "prev_events", | |
418 | room_id, | |
419 | event_id, | |
401 | "Error attempting to resolve state at missing " "prev_events", | |
420 | 402 | exc_info=True, |
421 | 403 | ) |
422 | 404 | raise FederationError( |
453 | 435 | latest |= seen |
454 | 436 | |
455 | 437 | logger.info( |
456 | "[%s %s]: Requesting missing events between %s and %s", | |
457 | room_id, | |
458 | event_id, | |
438 | "Requesting missing events between %s and %s", | |
459 | 439 | shortstr(latest), |
460 | 440 | event_id, |
461 | 441 | ) |
522 | 502 | # We failed to get the missing events, but since we need to handle |
523 | 503 | # the case of `get_missing_events` not returning the necessary |
524 | 504 | # events anyway, it is safe to simply log the error and continue. |
525 | logger.warning( | |
526 | "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e | |
527 | ) | |
505 | logger.warning("Failed to get prev_events: %s", e) | |
528 | 506 | return |
529 | 507 | |
530 | 508 | logger.info( |
531 | "[%s %s]: Got %d prev_events: %s", | |
532 | room_id, | |
533 | event_id, | |
509 | "Got %d prev_events: %s", | |
534 | 510 | len(missing_events), |
535 | 511 | shortstr(missing_events), |
536 | 512 | ) |
541 | 517 | |
542 | 518 | for ev in missing_events: |
543 | 519 | logger.info( |
544 | "[%s %s] Handling received prev_event %s", | |
545 | room_id, | |
546 | event_id, | |
520 | "Handling received prev_event %s", | |
547 | 521 | ev.event_id, |
548 | 522 | ) |
549 | 523 | with nested_logging_context(ev.event_id): |
552 | 526 | except FederationError as e: |
553 | 527 | if e.code == 403: |
554 | 528 | logger.warning( |
555 | "[%s %s] Received prev_event %s failed history check.", | |
556 | room_id, | |
557 | event_id, | |
529 | "Received prev_event %s failed history check.", | |
558 | 530 | ev.event_id, |
559 | 531 | ) |
560 | 532 | else: |
565 | 537 | destination: str, |
566 | 538 | room_id: str, |
567 | 539 | event_id: str, |
568 | include_event_in_state: bool = False, | |
569 | 540 | ) -> Tuple[List[EventBase], List[EventBase]]: |
570 | 541 | """Requests all of the room state at a given event from a remote homeserver. |
571 | 542 | |
573 | 544 | destination: The remote homeserver to query for the state. |
574 | 545 | room_id: The id of the room we're interested in. |
575 | 546 | event_id: The id of the event we want the state at. |
576 | include_event_in_state: if true, the event itself will be included in the | |
577 | returned state event list. | |
578 | 547 | |
579 | 548 | Returns: |
580 | A list of events in the state, possibly including the event itself, and | |
549 | A list of events in the state, not including the event itself, and | |
581 | 550 | a list of events in the auth chain for the given event. |
582 | 551 | """ |
583 | 552 | ( |
588 | 557 | ) |
589 | 558 | |
590 | 559 | desired_events = set(state_event_ids + auth_event_ids) |
591 | ||
592 | if include_event_in_state: | |
593 | desired_events.add(event_id) | |
594 | 560 | |
595 | 561 | event_map = await self._get_events_from_store_or_dest( |
596 | 562 | destination, room_id, desired_events |
607 | 573 | remote_state = [ |
608 | 574 | event_map[e_id] for e_id in state_event_ids if e_id in event_map |
609 | 575 | ] |
610 | ||
611 | if include_event_in_state: | |
612 | remote_event = event_map.get(event_id) | |
613 | if not remote_event: | |
614 | raise Exception("Unable to get missing prev_event %s" % (event_id,)) | |
615 | if remote_event.is_state() and remote_event.rejected_reason is None: | |
616 | remote_state.append(remote_event) | |
617 | 576 | |
618 | 577 | auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] |
619 | 578 | auth_chain.sort(key=lambda e: e.depth) |
688 | 647 | |
689 | 648 | return fetched_events |
690 | 649 | |
650 | async def _get_state_after_missing_prev_event( | |
651 | self, | |
652 | destination: str, | |
653 | room_id: str, | |
654 | event_id: str, | |
655 | ) -> List[EventBase]: | |
656 | """Requests all of the room state at a given event from a remote homeserver. | |
657 | ||
658 | Args: | |
659 | destination: The remote homeserver to query for the state. | |
660 | room_id: The id of the room we're interested in. | |
661 | event_id: The id of the event we want the state at. | |
662 | ||
663 | Returns: | |
664 | A list of events in the state, including the event itself | |
665 | """ | |
666 | # TODO: This function is basically the same as _get_state_for_room. Can | |
667 | # we make backfill() use it, rather than having two code paths? I think the | |
668 | # only difference is that backfill() persists the prev events separately. | |
669 | ||
670 | ( | |
671 | state_event_ids, | |
672 | auth_event_ids, | |
673 | ) = await self.federation_client.get_room_state_ids( | |
674 | destination, room_id, event_id=event_id | |
675 | ) | |
676 | ||
677 | logger.debug( | |
678 | "state_ids returned %i state events, %i auth events", | |
679 | len(state_event_ids), | |
680 | len(auth_event_ids), | |
681 | ) | |
682 | ||
683 | # start by just trying to fetch the events from the store | |
684 | desired_events = set(state_event_ids) | |
685 | desired_events.add(event_id) | |
686 | logger.debug("Fetching %i events from cache/store", len(desired_events)) | |
687 | fetched_events = await self.store.get_events( | |
688 | desired_events, allow_rejected=True | |
689 | ) | |
690 | ||
691 | missing_desired_events = desired_events - fetched_events.keys() | |
692 | logger.debug( | |
693 | "We are missing %i events (got %i)", | |
694 | len(missing_desired_events), | |
695 | len(fetched_events), | |
696 | ) | |
697 | ||
698 | # We probably won't need most of the auth events, so let's just check which | |
699 | # we have for now, rather than thrashing the event cache with them all | |
700 | # unnecessarily. | |
701 | ||
702 | # TODO: we probably won't actually need all of the auth events, since we | |
703 | # already have a bunch of the state events. It would be nice if the | |
704 | # federation api gave us a way of finding out which we actually need. | |
705 | ||
706 | missing_auth_events = set(auth_event_ids) - fetched_events.keys() | |
707 | missing_auth_events.difference_update( | |
708 | await self.store.have_seen_events(missing_auth_events) | |
709 | ) | |
710 | logger.debug("We are also missing %i auth events", len(missing_auth_events)) | |
711 | ||
712 | missing_events = missing_desired_events | missing_auth_events | |
713 | logger.debug("Fetching %i events from remote", len(missing_events)) | |
714 | await self._get_events_and_persist( | |
715 | destination=destination, room_id=room_id, events=missing_events | |
716 | ) | |
717 | ||
718 | # we need to make sure we re-load from the database to get the rejected | |
719 | # state correct. | |
720 | fetched_events.update( | |
721 | (await self.store.get_events(missing_desired_events, allow_rejected=True)) | |
722 | ) | |
723 | ||
724 | # check for events which were in the wrong room. | |
725 | # | |
726 | # this can happen if a remote server claims that the state or | |
727 | # auth_events at an event in room A are actually events in room B | |
728 | ||
729 | bad_events = [ | |
730 | (event_id, event.room_id) | |
731 | for event_id, event in fetched_events.items() | |
732 | if event.room_id != room_id | |
733 | ] | |
734 | ||
735 | for bad_event_id, bad_room_id in bad_events: | |
736 | # This is a bogus situation, but since we may only discover it a long time | |
737 | # after it happened, we try our best to carry on, by just omitting the | |
738 | # bad events from the returned state set. | |
739 | logger.warning( | |
740 | "Remote server %s claims event %s in room %s is an auth/state " | |
741 | "event in room %s", | |
742 | destination, | |
743 | bad_event_id, | |
744 | bad_room_id, | |
745 | room_id, | |
746 | ) | |
747 | ||
748 | del fetched_events[bad_event_id] | |
749 | ||
750 | # if we couldn't get the prev event in question, that's a problem. | |
751 | remote_event = fetched_events.get(event_id) | |
752 | if not remote_event: | |
753 | raise Exception("Unable to get missing prev_event %s" % (event_id,)) | |
754 | ||
755 | # missing state at that event is a warning, not a blocker | |
756 | # XXX: this doesn't sound right? it means that we'll end up with incomplete | |
757 | # state. | |
758 | failed_to_fetch = desired_events - fetched_events.keys() | |
759 | if failed_to_fetch: | |
760 | logger.warning( | |
761 | "Failed to fetch missing state events for %s %s", | |
762 | event_id, | |
763 | failed_to_fetch, | |
764 | ) | |
765 | ||
766 | remote_state = [ | |
767 | fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events | |
768 | ] | |
769 | ||
770 | if remote_event.is_state() and remote_event.rejected_reason is None: | |
771 | remote_state.append(remote_event) | |
772 | ||
773 | return remote_state | |
774 | ||
691 | 775 | async def _process_received_pdu( |
692 | 776 | self, |
693 | 777 | origin: str, |
706 | 790 | (ie, we are missing one or more prev_events), the resolved state at the |
707 | 791 | event |
708 | 792 | """ |
709 | room_id = event.room_id | |
710 | event_id = event.event_id | |
711 | ||
712 | logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) | |
793 | logger.debug("Processing event: %s", event) | |
713 | 794 | |
714 | 795 | try: |
715 | 796 | await self._handle_new_event(origin, event, state=state) |
870 | 951 | destination=dest, |
871 | 952 | room_id=room_id, |
872 | 953 | event_id=e_id, |
873 | include_event_in_state=False, | |
874 | 954 | ) |
875 | 955 | auth_events.update({a.event_id: a for a in auth}) |
876 | 956 | auth_events.update({s.event_id: s for s in state}) |
1316 | 1396 | async def on_event_auth(self, event_id: str) -> List[EventBase]: |
1317 | 1397 | event = await self.store.get_event(event_id) |
1318 | 1398 | auth = await self.store.get_auth_chain( |
1319 | list(event.auth_event_ids()), include_given=True | |
1399 | event.room_id, list(event.auth_event_ids()), include_given=True | |
1320 | 1400 | ) |
1321 | 1401 | return list(auth) |
1322 | 1402 | |
1579 | 1659 | prev_state_ids = await context.get_prev_state_ids() |
1580 | 1660 | |
1581 | 1661 | state_ids = list(prev_state_ids.values()) |
1582 | auth_chain = await self.store.get_auth_chain(state_ids) | |
1662 | auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) | |
1583 | 1663 | |
1584 | 1664 | state = await self.store.get_events(list(prev_state_ids.values())) |
1585 | 1665 | |
2218 | 2298 | |
2219 | 2299 | # Now get the current auth_chain for the event. |
2220 | 2300 | local_auth_chain = await self.store.get_auth_chain( |
2221 | list(event.auth_event_ids()), include_given=True | |
2301 | room_id, list(event.auth_event_ids()), include_given=True | |
2222 | 2302 | ) |
2223 | 2303 | |
2224 | 2304 | # TODO: Check if we would now reject event_id. If so we need to tell |
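Each of the `get_auth_chain` / `get_auth_chain_ids` hunks in this file is the same mechanical change: the room ID becomes the first argument, which (as far as this diff shows) lets the store scope its auth-chain lookups to a single room's chain-cover data. Schematically:

    async def auth_chain_for(store, room_id: str, event_ids: list) -> list:
        # old: await store.get_auth_chain(event_ids, include_given=True)
        # new: room ID first, same keyword arguments
        return await store.get_auth_chain(room_id, event_ids, include_given=True)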
47 | 47 | self.clock = hs.get_clock() |
48 | 48 | self.validator = EventValidator() |
49 | 49 | self.snapshot_cache = ResponseCache( |
50 | hs, "initial_sync_cache" | |
50 | hs.get_clock(), "initial_sync_cache" | |
51 | 51 | ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] |
52 | 52 | self._event_serializer = hs.get_event_client_serializer() |
53 | 53 | self.storage = hs.get_storage() |
0 | 0 | # -*- coding: utf-8 -*- |
1 | 1 | # Copyright 2020 Quentin Gliech |
2 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
2 | 3 | # |
3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 5 | # you may not use this file except in compliance with the License. |
13 | 14 | # limitations under the License. |
14 | 15 | import inspect |
15 | 16 | import logging |
16 | from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar | |
17 | from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union | |
17 | 18 | from urllib.parse import urlencode |
18 | 19 | |
19 | 20 | import attr |
20 | 21 | import pymacaroons |
21 | 22 | from authlib.common.security import generate_token |
22 | from authlib.jose import JsonWebToken | |
23 | from authlib.jose import JsonWebToken, jwt | |
23 | 24 | from authlib.oauth2.auth import ClientAuth |
24 | 25 | from authlib.oauth2.rfc6749.parameters import prepare_grant_uri |
25 | 26 | from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo |
27 | 28 | from jinja2 import Environment, Template |
28 | 29 | from pymacaroons.exceptions import ( |
29 | 30 | MacaroonDeserializationException, |
31 | MacaroonInitException, | |
30 | 32 | MacaroonInvalidSignatureException, |
31 | 33 | ) |
32 | 34 | from typing_extensions import TypedDict |
33 | 35 | |
34 | 36 | from twisted.web.client import readBody |
37 | from twisted.web.http_headers import Headers | |
35 | 38 | |
36 | 39 | from synapse.config import ConfigError |
37 | from synapse.config.oidc_config import OidcProviderConfig | |
40 | from synapse.config.oidc_config import ( | |
41 | OidcProviderClientSecretJwtKey, | |
42 | OidcProviderConfig, | |
43 | ) | |
38 | 44 | from synapse.handlers.sso import MappingException, UserAttributes |
39 | 45 | from synapse.http.site import SynapseRequest |
40 | 46 | from synapse.logging.context import make_deferred_yieldable |
41 | 47 | from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart |
42 | from synapse.util import json_decoder | |
48 | from synapse.util import Clock, json_decoder | |
43 | 49 | from synapse.util.caches.cached_call import RetryOnExceptionCachedCall |
50 | from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry | |
44 | 51 | |
45 | 52 | if TYPE_CHECKING: |
46 | 53 | from synapse.server import HomeServer |
210 | 217 | session_data = self._token_generator.verify_oidc_session_token( |
211 | 218 | session, state |
212 | 219 | ) |
213 | except (MacaroonDeserializationException, ValueError) as e: | |
220 | except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e: | |
214 | 221 | logger.exception("Invalid session for OIDC callback") |
215 | 222 | self._sso_handler.render_error(request, "invalid_session", str(e)) |
216 | 223 | return |
274 | 281 | |
275 | 282 | self._scopes = provider.scopes |
276 | 283 | self._user_profile_method = provider.user_profile_method |
284 | ||
285 | client_secret = None # type: Union[None, str, JwtClientSecret] | |
286 | if provider.client_secret: | |
287 | client_secret = provider.client_secret | |
288 | elif provider.client_secret_jwt_key: | |
289 | client_secret = JwtClientSecret( | |
290 | provider.client_secret_jwt_key, | |
291 | provider.client_id, | |
292 | provider.issuer, | |
293 | hs.get_clock(), | |
294 | ) | |
295 | ||
277 | 296 | self._client_auth = ClientAuth( |
278 | 297 | provider.client_id, |
279 | provider.client_secret, | |
298 | client_secret, | |
280 | 299 | provider.client_auth_method, |
281 | 300 | ) # type: ClientAuth |
282 | 301 | self._client_auth_method = provider.client_auth_method |
310 | 329 | |
311 | 330 | # optional brand identifier for this auth provider |
312 | 331 | self.idp_brand = provider.idp_brand |
332 | ||
333 | # Optional brand identifier for the unstable API (see MSC2858). | |
334 | self.unstable_idp_brand = provider.unstable_idp_brand | |
313 | 335 | |
314 | 336 | self._sso_handler = hs.get_sso_handler() |
315 | 337 | |
520 | 542 | """ |
521 | 543 | metadata = await self.load_metadata() |
522 | 544 | token_endpoint = metadata.get("token_endpoint") |
523 | headers = { | |
545 | raw_headers = { | |
524 | 546 | "Content-Type": "application/x-www-form-urlencoded", |
525 | 547 | "User-Agent": self._http_client.user_agent, |
526 | 548 | "Accept": "application/json", |
534 | 556 | body = urlencode(args, True) |
535 | 557 | |
536 | 558 | # Fill the body/headers with credentials |
537 | uri, headers, body = self._client_auth.prepare( | |
538 | method="POST", uri=token_endpoint, headers=headers, body=body | |
539 | ) | |
540 | headers = {k: [v] for (k, v) in headers.items()} | |
559 | uri, raw_headers, body = self._client_auth.prepare( | |
560 | method="POST", uri=token_endpoint, headers=raw_headers, body=body | |
561 | ) | |
562 | headers = Headers({k: [v] for (k, v) in raw_headers.items()}) | |
541 | 563 | |
542 | 564 | # Do the actual request |
543 | 565 | # We're not using the SimpleHttpClient util methods as we don't want to |
744 | 766 | idp_id=self.idp_id, |
745 | 767 | nonce=nonce, |
746 | 768 | client_redirect_url=client_redirect_url.decode(), |
747 | ui_auth_session_id=ui_auth_session_id, | |
769 | ui_auth_session_id=ui_auth_session_id or "", | |
748 | 770 | ), |
749 | 771 | ) |
750 | 772 | |
975 | 997 | return str(remote_user_id) |
976 | 998 | |
977 | 999 | |
1000 | # number of seconds a newly-generated client secret should be valid for | |
1001 | CLIENT_SECRET_VALIDITY_SECONDS = 3600 | |
1002 | ||
1003 | # minimum remaining validity on a client secret before we should generate a new one | |
1004 | CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600 | |
1005 | ||
1006 | ||
1007 | class JwtClientSecret: | |
1008 | """A class which generates a new client secret on demand, based on a JWK | |
1009 | ||
1010 | This implementation is designed to comply with the requirements for Apple Sign in: | |
1011 | https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048 | |
1012 | ||
1013 | It looks like those requirements are based on https://tools.ietf.org/html/rfc7523, | |
1014 | but it's worth noting that we still put the generated secret in the "client_secret" | |
1015 | field (or rather, wherever client_auth_method puts it) rather than in a |
1016 | client_assertion field in the body as that RFC seems to require. | |
1017 | """ | |
1018 | ||
1019 | def __init__( | |
1020 | self, | |
1021 | key: OidcProviderClientSecretJwtKey, | |
1022 | oauth_client_id: str, | |
1023 | oauth_issuer: str, | |
1024 | clock: Clock, | |
1025 | ): | |
1026 | self._key = key | |
1027 | self._oauth_client_id = oauth_client_id | |
1028 | self._oauth_issuer = oauth_issuer | |
1029 | self._clock = clock | |
1030 | self._cached_secret = b"" | |
1031 | self._cached_secret_replacement_time = 0 | |
1032 | ||
1033 | def __str__(self): | |
1034 | # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls | |
1035 | # encode_client_secret_basic, which calls "{}".format(secret), which ends up | |
1036 | # here. | |
1037 | return self._get_secret().decode("ascii") | |
1038 | ||
1039 | def __bytes__(self): | |
1040 | # if client_auth_method is client_secret_post, then ClientAuth.prepare calls | |
1041 | # encode_client_secret_post, which ends up here. | |
1042 | return self._get_secret() | |
1043 | ||
1044 | def _get_secret(self) -> bytes: | |
1045 | now = self._clock.time() | |
1046 | ||
1047 | # if we have enough validity on our existing secret, use it | |
1048 | if now < self._cached_secret_replacement_time: | |
1049 | return self._cached_secret | |
1050 | ||
1051 | issued_at = int(now) | |
1052 | expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS | |
1053 | ||
1054 | # we copy the configured header because jwt.encode modifies it. | |
1055 | header = dict(self._key.jwt_header) | |
1056 | ||
1057 | # see https://tools.ietf.org/html/rfc7523#section-3 | |
1058 | payload = { | |
1059 | "sub": self._oauth_client_id, | |
1060 | "aud": self._oauth_issuer, | |
1061 | "iat": issued_at, | |
1062 | "exp": expires_at, | |
1063 | **self._key.jwt_payload, | |
1064 | } | |
1065 | logger.info( | |
1066 | "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload | |
1067 | ) | |
1068 | self._cached_secret = jwt.encode(header, payload, self._key.key) | |
1069 | self._cached_secret_replacement_time = ( | |
1070 | expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS | |
1071 | ) | |
1072 | return self._cached_secret | |
1073 | ||
1074 | ||
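`JwtClientSecret._get_secret` above leans on authlib's `jwt.encode(header, payload, key)`, which returns the signed compact JWT as bytes; the surrounding class just caches that result until it nears expiry. A toy illustration of the encode step, using a symmetric HS256 key for brevity (Apple Sign in itself would use the configured ES256 private key, and all claim values here are placeholders):

    import time

    from authlib.jose import jwt

    key = "top-secret"  # placeholder; real deployments use the configured JWK
    header = {"alg": "HS256"}

    now = int(time.time())
    payload = {
        "sub": "com.example.client",      # the OAuth client id
        "aud": "https://issuer.example",  # the OAuth issuer
        "iat": now,
        "exp": now + 3600,
        "iss": "TEAMID",                  # extra claims from the jwt_payload config
    }

    # note dict(header): jwt.encode mutates the header, hence the copy upstream
    client_secret = jwt.encode(dict(header), payload, key)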
978 | 1075 | class OidcSessionTokenGenerator: |
979 | 1076 | """Methods for generating and checking OIDC Session cookies.""" |
980 | 1077 | |
1019 | 1116 | macaroon.add_first_party_caveat( |
1020 | 1117 | "client_redirect_url = %s" % (session_data.client_redirect_url,) |
1021 | 1118 | ) |
1022 | if session_data.ui_auth_session_id: | |
1023 | macaroon.add_first_party_caveat( | |
1024 | "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) | |
1025 | ) | |
1119 | macaroon.add_first_party_caveat( | |
1120 | "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) | |
1121 | ) | |
1026 | 1122 | now = self._clock.time_msec() |
1027 | 1123 | expiry = now + duration_in_ms |
1028 | 1124 | macaroon.add_first_party_caveat("time < %d" % (expiry,)) |
1045 | 1141 | The data extracted from the session cookie |
1046 | 1142 | |
1047 | 1143 | Raises: |
1048 | ValueError if an expected caveat is missing from the macaroon. | |
1144 | KeyError if an expected caveat is missing from the macaroon. | |
1049 | 1145 | """ |
1050 | 1146 | macaroon = pymacaroons.Macaroon.deserialize(session) |
1051 | 1147 | |
1056 | 1152 | v.satisfy_general(lambda c: c.startswith("nonce = ")) |
1057 | 1153 | v.satisfy_general(lambda c: c.startswith("idp_id = ")) |
1058 | 1154 | v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) |
1059 | # Sometimes there's a UI auth session ID, it seems to be OK to attempt | |
1060 | # to always satisfy this. | |
1061 | 1155 | v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) |
1062 | v.satisfy_general(self._verify_expiry) | |
1156 | satisfy_expiry(v, self._clock.time_msec) | |
1063 | 1157 | |
1064 | 1158 | v.verify(macaroon, self._macaroon_secret_key) |
1065 | 1159 | |
1066 | 1160 | # Extract the session data from the token. |
1067 | nonce = self._get_value_from_macaroon(macaroon, "nonce") | |
1068 | idp_id = self._get_value_from_macaroon(macaroon, "idp_id") | |
1069 | client_redirect_url = self._get_value_from_macaroon( | |
1070 | macaroon, "client_redirect_url" | |
1071 | ) | |
1072 | try: | |
1073 | ui_auth_session_id = self._get_value_from_macaroon( | |
1074 | macaroon, "ui_auth_session_id" | |
1075 | ) # type: Optional[str] | |
1076 | except ValueError: | |
1077 | ui_auth_session_id = None | |
1078 | ||
1161 | nonce = get_value_from_macaroon(macaroon, "nonce") | |
1162 | idp_id = get_value_from_macaroon(macaroon, "idp_id") | |
1163 | client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url") | |
1164 | ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id") | |
1079 | 1165 | return OidcSessionData( |
1080 | 1166 | nonce=nonce, |
1081 | 1167 | idp_id=idp_id, |
1083 | 1169 | ui_auth_session_id=ui_auth_session_id, |
1084 | 1170 | ) |
1085 | 1171 | |
1086 | def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: | |
1087 | """Extracts a caveat value from a macaroon token. | |
1088 | ||
1089 | Args: | |
1090 | macaroon: the token | |
1091 | key: the key of the caveat to extract | |
1092 | ||
1093 | Returns: | |
1094 | The extracted value | |
1095 | ||
1096 | Raises: | |
1097 | ValueError: if the caveat was not in the macaroon | |
1098 | """ | |
1099 | prefix = key + " = " | |
1100 | for caveat in macaroon.caveats: | |
1101 | if caveat.caveat_id.startswith(prefix): | |
1102 | return caveat.caveat_id[len(prefix) :] | |
1103 | raise ValueError("No %s caveat in macaroon" % (key,)) | |
1104 | ||
1105 | def _verify_expiry(self, caveat: str) -> bool: | |
1106 | prefix = "time < " | |
1107 | if not caveat.startswith(prefix): | |
1108 | return False | |
1109 | expiry = int(caveat[len(prefix) :]) | |
1110 | now = self._clock.time_msec() | |
1111 | return now < expiry | |
1112 | ||
1113 | 1172 | |
1114 | 1173 | @attr.s(frozen=True, slots=True) |
1115 | 1174 | class OidcSessionData: |
1124 | 1183 | # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) |
1125 | 1184 | client_redirect_url = attr.ib(type=str) |
1126 | 1185 | |
1127 | # The session ID of the ongoing UI Auth (None if this is a login) | |
1128 | ui_auth_session_id = attr.ib(type=Optional[str], default=None) | |
1186 | # The session ID of the ongoing UI Auth ("" if this is a login) | |
1187 | ui_auth_session_id = attr.ib(type=str) | |
1129 | 1188 | |
1130 | 1189 | |
1131 | 1190 | UserAttributeDict = TypedDict( |
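The private `_get_value_from_macaroon` / `_verify_expiry` helpers deleted above live on as the shared `get_value_from_macaroon` and `satisfy_expiry` imported at the top of this file. Reconstructed from the removed code (with the missing-caveat error switched to the `KeyError` that the callback handler now catches), they are equivalent to this sketch:

    from typing import Callable

    import pymacaroons

    def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
        prefix = key + " = "
        for caveat in macaroon.caveats:
            if caveat.caveat_id.startswith(prefix):
                return caveat.caveat_id[len(prefix):]
        raise KeyError("No %s caveat in macaroon" % (key,))

    def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
        def verify_expiry_caveat(caveat: str) -> bool:
            prefix = "time < "
            if not caveat.startswith(prefix):
                return False
            return get_time_ms() < int(caveat[len(prefix):])

        v.satisfy_general(verify_expiry_caveat)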
284 | 284 | except Exception: |
285 | 285 | f = Failure() |
286 | 286 | logger.error( |
287 | "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) | |
287 | "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore | |
288 | 288 | ) |
289 | 289 | self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED |
290 | 290 | finally: |
15 | 15 | """Contains functions for registering clients.""" |
16 | 16 | |
17 | 17 | import logging |
18 | from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple | |
18 | from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple | |
19 | ||
20 | from prometheus_client import Counter | |
19 | 21 | |
20 | 22 | from synapse import types |
21 | 23 | from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType |
40 | 42 | logger = logging.getLogger(__name__) |
41 | 43 | |
42 | 44 | |
45 | registration_counter = Counter( | |
46 | "synapse_user_registrations_total", | |
47 | "Number of new users registered (since restart)", | |
48 | ["guest", "shadow_banned", "auth_provider"], | |
49 | ) | |
50 | ||
51 | login_counter = Counter( | |
52 | "synapse_user_logins_total", | |
53 | "Number of user logins (since restart)", | |
54 | ["guest", "auth_provider"], | |
55 | ) | |
56 | ||
57 | ||
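These are plain `prometheus_client` counters keyed by label set: each distinct `(guest, shadow_banned, auth_provider)` or `(guest, auth_provider)` combination becomes its own time series. A minimal sketch of how such a counter is exercised and exposed (the metric name below is a demo stand-in, not the one defined above):

    from prometheus_client import Counter, generate_latest

    demo_login_counter = Counter(
        "demo_user_logins_total",
        "Number of user logins (since restart)",
        ["guest", "auth_provider"],
    )

    demo_login_counter.labels(guest=True, auth_provider="").inc()
    demo_login_counter.labels(guest=False, auth_provider="oidc").inc()

    # one exposition line per distinct label combination
    print(generate_latest().decode())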
43 | 58 | class RegistrationHandler(BaseHandler): |
44 | 59 | def __init__(self, hs: "HomeServer"): |
45 | 60 | super().__init__(hs) |
66 | 81 | ) |
67 | 82 | else: |
68 | 83 | self.device_handler = hs.get_device_handler() |
84 | self._register_device_client = self.register_device_inner | |
69 | 85 | self.pusher_pool = hs.get_pusherpool() |
70 | 86 | |
71 | 87 | self.session_lifetime = hs.config.session_lifetime |
155 | 171 | bind_emails: Iterable[str] = [], |
156 | 172 | by_admin: bool = False, |
157 | 173 | user_agent_ips: Optional[List[Tuple[str, str]]] = None, |
174 | auth_provider_id: Optional[str] = None, | |
158 | 175 | ) -> str: |
159 | 176 | """Registers a new client on the server. |
160 | 177 | |
180 | 197 | admin api, otherwise False. |
181 | 198 | user_agent_ips: Tuples of IP addresses and user-agents used |
182 | 199 | during the registration process. |
200 | auth_provider_id: The SSO IdP the user used, if any. | |
183 | 201 | Returns: |
184 | The registere user_id. | |
202 | The registered user_id. | |
185 | 203 | Raises: |
186 | 204 | SynapseError if there was a problem registering. |
187 | 205 | """ |
191 | 209 | threepid, |
192 | 210 | localpart, |
193 | 211 | user_agent_ips or [], |
212 | auth_provider_id=auth_provider_id, | |
194 | 213 | ) |
195 | 214 | |
196 | 215 | if result == RegistrationBehaviour.DENY: |
278 | 297 | except SynapseError: |
279 | 298 | # if user id is taken, just generate another |
280 | 299 | fail_count += 1 |
300 | ||
301 | registration_counter.labels( | |
302 | guest=make_guest, | |
303 | shadow_banned=shadow_banned, | |
304 | auth_provider=(auth_provider_id or ""), | |
305 | ).inc() | |
281 | 306 | |
282 | 307 | if not self.hs.config.user_consent_at_registration: |
283 | 308 | if not self.hs.config.auto_join_rooms_for_guests and make_guest: |
637 | 662 | initial_display_name: Optional[str], |
638 | 663 | is_guest: bool = False, |
639 | 664 | is_appservice_ghost: bool = False, |
665 | auth_provider_id: Optional[str] = None, | |
640 | 666 | ) -> Tuple[str, str]: |
641 | 667 | """Register a device for a user and generate an access token. |
642 | 668 | |
647 | 673 | device_id: The device ID to check, or None to generate a new one. |
648 | 674 | initial_display_name: An optional display name for the device. |
649 | 675 | is_guest: Whether this is a guest account |
650 | ||
676 | auth_provider_id: The SSO IdP the user used, if any (just used for the | |
677 | prometheus metrics). | |
651 | 678 | Returns: |
652 | 679 | Tuple of device ID and access token |
653 | 680 | """ |
654 | ||
655 | if self.hs.config.worker_app: | |
656 | r = await self._register_device_client( | |
657 | user_id=user_id, | |
658 | device_id=device_id, | |
659 | initial_display_name=initial_display_name, | |
660 | is_guest=is_guest, | |
661 | is_appservice_ghost=is_appservice_ghost, | |
662 | ) | |
663 | return r["device_id"], r["access_token"] | |
664 | ||
681 | res = await self._register_device_client( | |
682 | user_id=user_id, | |
683 | device_id=device_id, | |
684 | initial_display_name=initial_display_name, | |
685 | is_guest=is_guest, | |
686 | is_appservice_ghost=is_appservice_ghost, | |
687 | ) | |
688 | ||
689 | login_counter.labels( | |
690 | guest=is_guest, | |
691 | auth_provider=(auth_provider_id or ""), | |
692 | ).inc() | |
693 | ||
694 | return res["device_id"], res["access_token"] | |
695 | ||
696 | async def register_device_inner( | |
697 | self, | |
698 | user_id: str, | |
699 | device_id: Optional[str], | |
700 | initial_display_name: Optional[str], | |
701 | is_guest: bool = False, | |
702 | is_appservice_ghost: bool = False, | |
703 | ) -> Dict[str, str]: | |
704 | """Helper for register_device | |
705 | ||
706 | Does the bits that need doing on the main process. Not for use outside this | |
707 | class and RegisterDeviceReplicationServlet. | |
708 | """ | |
709 | assert not self.hs.config.worker_app | |
665 | 710 | valid_until_ms = None |
666 | 711 | if self.session_lifetime is not None: |
667 | 712 | if is_guest: |
686 | 731 | is_appservice_ghost=is_appservice_ghost, |
687 | 732 | ) |
688 | 733 | |
689 | return (registered_device_id, access_token) | |
734 | return {"device_id": registered_device_id, "access_token": access_token} | |
690 | 735 | |
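After this change the worker path and the main-process path both funnel through `self._register_device_client` and return the same dict shape, which is what lets `register_device` increment `login_counter` uniformly. A hedged sketch of that split, with demo names throughout (the replication client is stubbed out):

    import asyncio
    from typing import Dict

    class DemoRegistrationHandler:
        def __init__(self, on_worker: bool):
            if on_worker:
                # stands in for the HTTP replication client used on workers
                self._register_device_client = self._call_main_process
            else:
                self._register_device_client = self.register_device_inner

        async def register_device(self, user_id: str) -> Dict[str, str]:
            # both paths return the same dict shape
            return await self._register_device_client(user_id)

        async def register_device_inner(self, user_id: str) -> Dict[str, str]:
            return {"device_id": "DEV1", "access_token": "tok"}  # illustrative

        async def _call_main_process(self, user_id: str) -> Dict[str, str]:
            return await self.register_device_inner(user_id)

    print(asyncio.run(DemoRegistrationHandler(on_worker=True).register_device("@u:s")))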
691 | 736 | async def post_registration_actions( |
692 | 737 | self, user_id: str, auth_result: dict, access_token: Optional[str] |
120 | 120 | # succession, only process the first attempt and return its result to |
121 | 121 | # subsequent requests |
122 | 122 | self._upgrade_response_cache = ResponseCache( |
123 | hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS | |
123 | hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS | |
124 | 124 | ) # type: ResponseCache[Tuple[str, str]] |
125 | 125 | self._server_notices_mxid = hs.config.server_notices_mxid |
126 | 126 |
43 | 43 | super().__init__(hs) |
44 | 44 | self.enable_room_list_search = hs.config.enable_room_list_search |
45 | 45 | self.response_cache = ResponseCache( |
46 | hs, "room_list" | |
46 | hs.get_clock(), "room_list" | |
47 | 47 | ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] |
48 | 48 | self.remote_response_cache = ResponseCache( |
49 | hs, "remote_room_list", timeout_ms=30 * 1000 | |
49 | hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 | |
50 | 50 | ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] |
51 | 51 | |
52 | 52 | async def get_local_public_room_list( |
80 | 80 | # the SsoIdentityProvider protocol type. |
81 | 81 | self.idp_icon = None |
82 | 82 | self.idp_brand = None |
83 | self.unstable_idp_brand = None | |
83 | 84 | |
84 | 85 | # a map from saml session id to Saml2SessionData object |
85 | 86 | self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] |
97 | 97 | """Optional branding identifier""" |
98 | 98 | return None |
99 | 99 | |
100 | @property | |
101 | def unstable_idp_brand(self) -> Optional[str]: | |
102 | """Optional brand identifier for the unstable API (see MSC2858).""" | |
103 | return None | |
104 | ||
100 | 105 | @abc.abstractmethod |
101 | 106 | async def handle_redirect_request( |
102 | 107 | self, |
455 | 460 | |
456 | 461 | await self._auth_handler.complete_sso_login( |
457 | 462 | user_id, |
463 | auth_provider_id, | |
458 | 464 | request, |
459 | 465 | client_redirect_url, |
460 | 466 | extra_login_attributes, |
604 | 610 | default_display_name=attributes.display_name, |
605 | 611 | bind_emails=attributes.emails, |
606 | 612 | user_agent_ips=[(user_agent, ip_address)], |
613 | auth_provider_id=auth_provider_id, | |
607 | 614 | ) |
608 | 615 | |
609 | 616 | await self._store.record_user_external_id( |
885 | 892 | |
886 | 893 | await self._auth_handler.complete_sso_login( |
887 | 894 | user_id, |
895 | session.auth_provider_id, | |
888 | 896 | request, |
889 | 897 | session.client_redirect_url, |
890 | 898 | session.extra_login_attributes, |
243 | 243 | self.event_sources = hs.get_event_sources() |
244 | 244 | self.clock = hs.get_clock() |
245 | 245 | self.response_cache = ResponseCache( |
246 | hs, "sync" | |
246 | hs.get_clock(), "sync" | |
247 | 247 | ) # type: ResponseCache[Tuple[Any, ...]] |
248 | 248 | self.state = hs.get_state_handler() |
249 | 249 | self.auth = hs.get_auth() |
38 | 38 | from OpenSSL import SSL |
39 | 39 | from OpenSSL.SSL import VERIFY_NONE |
40 | 40 | from twisted.internet import defer, error as twisted_error, protocol, ssl |
41 | from twisted.internet.address import IPv4Address, IPv6Address | |
41 | 42 | from twisted.internet.interfaces import ( |
42 | 43 | IAddress, |
43 | 44 | IHostResolution, |
44 | 45 | IReactorPluggableNameResolver, |
45 | 46 | IResolutionReceiver, |
47 | ITCPTransport, | |
46 | 48 | ) |
49 | from twisted.internet.protocol import connectionDone | |
47 | 50 | from twisted.internet.task import Cooperator |
48 | 51 | from twisted.python.failure import Failure |
49 | 52 | from twisted.web._newclient import ResponseDone |
55 | 58 | ) |
56 | 59 | from twisted.web.http import PotentialDataLoss |
57 | 60 | from twisted.web.http_headers import Headers |
58 | from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse | |
61 | from twisted.web.iweb import ( | |
62 | UNKNOWN_LENGTH, | |
63 | IAgent, | |
64 | IBodyProducer, | |
65 | IPolicyForHTTPS, | |
66 | IResponse, | |
67 | ) | |
59 | 68 | |
60 | 69 | from synapse.api.errors import Codes, HttpResponseException, SynapseError |
61 | 70 | from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri |
62 | 71 | from synapse.http.proxyagent import ProxyAgent |
63 | 72 | from synapse.logging.context import make_deferred_yieldable |
64 | 73 | from synapse.logging.opentracing import set_tag, start_active_span, tags |
74 | from synapse.types import ISynapseReactor | |
65 | 75 | from synapse.util import json_decoder |
66 | 76 | from synapse.util.async_helpers import timeout_deferred |
67 | 77 | |
149 | 159 | def resolveHostName( |
150 | 160 | self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 |
151 | 161 | ) -> IResolutionReceiver: |
152 | ||
153 | r = recv() | |
154 | 162 | addresses = [] # type: List[IAddress] |
155 | 163 | |
156 | 164 | def _callback() -> None: |
157 | r.resolutionBegan(None) | |
158 | ||
159 | 165 | has_bad_ip = False |
160 | for i in addresses: | |
161 | ip_address = IPAddress(i.host) | |
166 | for address in addresses: | |
167 | # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups | |
168 | # should go through this path. | |
169 | if not isinstance(address, (IPv4Address, IPv6Address)): | |
170 | continue | |
171 | ||
172 | ip_address = IPAddress(address.host) | |
162 | 173 | |
163 | 174 | if check_against_blacklist( |
164 | 175 | ip_address, self._ip_whitelist, self._ip_blacklist |
173 | 184 | # request, but all we can really do from here is claim that there were no |
174 | 185 | # valid results. |
175 | 186 | if not has_bad_ip: |
176 | for i in addresses: | |
177 | r.addressResolved(i) | |
178 | r.resolutionComplete() | |
187 | for address in addresses: | |
188 | recv.addressResolved(address) | |
189 | recv.resolutionComplete() | |
179 | 190 | |
180 | 191 | @provider(IResolutionReceiver) |
181 | 192 | class EndpointReceiver: |
182 | 193 | @staticmethod |
183 | 194 | def resolutionBegan(resolutionInProgress: IHostResolution) -> None: |
184 | pass | |
195 | recv.resolutionBegan(resolutionInProgress) | |
185 | 196 | |
186 | 197 | @staticmethod |
187 | 198 | def addressResolved(address: IAddress) -> None: |
195 | 206 | EndpointReceiver, hostname, portNumber=portNumber |
196 | 207 | ) |
197 | 208 | |
198 | return r | |
199 | ||
200 | ||
201 | @implementer(IReactorPluggableNameResolver) | |
209 | return recv | |
210 | ||
211 | ||
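Each resolved address is now type-checked and then tested against the configured IP ranges before being reported back to the receiver. A standalone sketch of that per-address decision, assuming netaddr is installed and that, as with `check_against_blacklist` here, the whitelist overrides the blacklist (range values are illustrative):

    from netaddr import IPAddress, IPSet

    ip_blacklist = IPSet(["10.0.0.0/8", "127.0.0.0/8", "::1/128"])
    ip_whitelist = IPSet(["10.1.2.3/32"])

    def is_blocked(host: str) -> bool:
        ip = IPAddress(host)
        # blocked if blacklisted, unless explicitly whitelisted
        return ip in ip_blacklist and ip not in ip_whitelist

    print([is_blocked(h) for h in ("10.9.9.9", "10.1.2.3", "8.8.8.8")])
    # -> [True, False, False]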
212 | @implementer(ISynapseReactor) | |
202 | 213 | class BlacklistingReactorWrapper: |
203 | 214 | """ |
204 | 215 | A Reactor wrapper which will prevent DNS resolution to blacklisted IP |
323 | 334 | # filters out blacklisted IP addresses, to prevent DNS rebinding. |
324 | 335 | self.reactor = BlacklistingReactorWrapper( |
325 | 336 | hs.get_reactor(), self._ip_whitelist, self._ip_blacklist |
326 | ) | |
337 | ) # type: ISynapseReactor | |
327 | 338 | else: |
328 | 339 | self.reactor = hs.get_reactor() |
329 | 340 | |
344 | 355 | contextFactory=self.hs.get_http_client_context_factory(), |
345 | 356 | pool=pool, |
346 | 357 | use_proxy=use_proxy, |
347 | ) | |
358 | ) # type: IAgent | |
348 | 359 | |
349 | 360 | if self._ip_blacklist: |
350 | 361 | # If we have an IP blacklist, we then install the blacklisting Agent |
750 | 761 | class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): |
751 | 762 | """A protocol which immediately errors upon receiving data.""" |
752 | 763 | |
764 | transport = None # type: Optional[ITCPTransport] | |
765 | ||
753 | 766 | def __init__(self, deferred: defer.Deferred): |
754 | 767 | self.deferred = deferred |
755 | 768 | |
761 | 774 | self.deferred.errback(BodyExceededMaxSize()) |
762 | 775 | # Close the connection (forcefully) since all the data will get |
763 | 776 | # discarded anyway. |
777 | assert self.transport is not None | |
764 | 778 | self.transport.abortConnection() |
765 | 779 | |
766 | 780 | def dataReceived(self, data: bytes) -> None: |
767 | 781 | self._maybe_fail() |
768 | 782 | |
769 | def connectionLost(self, reason: Failure) -> None: | |
783 | def connectionLost(self, reason: Failure = connectionDone) -> None: | |
770 | 784 | self._maybe_fail() |
771 | 785 | |
772 | 786 | |
773 | 787 | class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): |
774 | 788 | """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" |
789 | ||
790 | transport = None # type: Optional[ITCPTransport] | |
775 | 791 | |
776 | 792 | def __init__( |
777 | 793 | self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] |
795 | 811 | self.deferred.errback(BodyExceededMaxSize()) |
796 | 812 | # Close the connection (forcefully) since all the data will get |
797 | 813 | # discarded anyway. |
814 | assert self.transport is not None | |
798 | 815 | self.transport.abortConnection() |
799 | 816 | |
800 | def connectionLost(self, reason: Failure) -> None: | |
817 | def connectionLost(self, reason: Failure = connectionDone) -> None: | |
801 | 818 | # If the maximum size was already exceeded, there's nothing to do. |
802 | 819 | if self.deferred.called: |
803 | 820 | return |
866 | 883 | return query_str.encode("utf8") |
867 | 884 | |
868 | 885 | |
886 | @implementer(IPolicyForHTTPS) | |
869 | 887 | class InsecureInterceptableContextFactory(ssl.ContextFactory): |
870 | 888 | """ |
871 | 889 | Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain. |
34 | 34 | from synapse.http.federation.srv_resolver import Server, SrvResolver |
35 | 35 | from synapse.http.federation.well_known_resolver import WellKnownResolver |
36 | 36 | from synapse.logging.context import make_deferred_yieldable, run_in_background |
37 | from synapse.types import ISynapseReactor | |
37 | 38 | from synapse.util import Clock |
38 | 39 | |
39 | 40 | logger = logging.getLogger(__name__) |
67 | 68 | |
68 | 69 | def __init__( |
69 | 70 | self, |
70 | reactor: IReactorCore, | |
71 | reactor: ISynapseReactor, | |
71 | 72 | tls_client_options_factory: Optional[FederationPolicyForHTTPS], |
72 | 73 | user_agent: bytes, |
73 | 74 | ip_blacklist: IPSet, |
321 | 321 | |
322 | 322 | def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: |
323 | 323 | cache_controls = {} |
324 | for hdr in headers.getRawHeaders(b"cache-control", []): | |
324 | cache_control_headers = headers.getRawHeaders(b"cache-control") or [] | |
325 | for hdr in cache_control_headers: | |
325 | 326 | for directive in hdr.split(b","): |
326 | 327 | splits = [x.strip() for x in directive.split(b"=", 1)] |
327 | 328 | k = splits[0].lower() |
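`Headers.getRawHeaders` returns `None` when the header is absent (unless a default is supplied), so the `or []` guard keeps the loop type-safe. A runnable sketch of the same parse, assuming Twisted is installed:

    from twisted.web.http_headers import Headers

    headers = Headers({b"Cache-Control": [b"max-age=60, no-store"]})

    cache_controls = {}
    for hdr in headers.getRawHeaders(b"cache-control") or []:
        for directive in hdr.split(b","):
            splits = [x.strip() for x in directive.split(b"=", 1)]
            k = splits[0].lower()
            cache_controls[k] = splits[1] if len(splits) > 1 else None

    print(cache_controls)  # {b'max-age': b'60', b'no-store': None}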
58 | 58 | start_active_span, |
59 | 59 | tags, |
60 | 60 | ) |
61 | from synapse.types import JsonDict | |
61 | from synapse.types import ISynapseReactor, JsonDict | |
62 | 62 | from synapse.util import json_decoder |
63 | 63 | from synapse.util.async_helpers import timeout_deferred |
64 | 64 | from synapse.util.metrics import Measure |
236 | 236 | # addresses, to prevent DNS rebinding. |
237 | 237 | self.reactor = BlacklistingReactorWrapper( |
238 | 238 | hs.get_reactor(), None, hs.config.federation_ip_range_blacklist |
239 | ) | |
239 | ) # type: ISynapseReactor | |
240 | 240 | |
241 | 241 | user_agent = hs.version_string |
242 | 242 | if hs.config.user_agent_suffix: |
243 | 243 | user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) |
244 | 244 | user_agent = user_agent.encode("ascii") |
245 | 245 | |
246 | self.agent = MatrixFederationAgent( | |
246 | federation_agent = MatrixFederationAgent( | |
247 | 247 | self.reactor, |
248 | 248 | tls_client_options_factory, |
249 | 249 | user_agent, |
253 | 253 | # Use a BlacklistingAgentWrapper to prevent circumventing the IP |
254 | 254 | # blacklist via IP literals in server names |
255 | 255 | self.agent = BlacklistingAgentWrapper( |
256 | self.agent, | |
256 | federation_agent, | |
257 | 257 | ip_blacklist=hs.config.federation_ip_range_blacklist, |
258 | 258 | ) |
259 | 259 | |
533 | 533 | response.code, response_phrase, body |
534 | 534 | ) |
535 | 535 | |
536 | # Retry if the error is a 429 (Too Many Requests), | |
537 | # otherwise just raise a standard HttpResponseException | |
538 | if response.code == 429: | |
536 | # Retry if the error is a 5xx or a 429 (Too Many | |
537 | # Requests), otherwise just raise a standard | |
538 | # `HttpResponseException` | |
539 | if 500 <= response.code < 600 or response.code == 429: | |
539 | 540 | raise RequestSendFailed(exc, can_retry=True) from exc |
540 | 541 | else: |
541 | 542 | raise exc |
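The change above widens the federation retry policy from 429-only to any 5xx server error as well. The new rule as a one-line predicate (hypothetical helper name, for illustration):

    def should_retry(response_code: int) -> bool:
        # mirrors the new condition: any 5xx server error, or 429
        return 500 <= response_code < 600 or response_code == 429

    assert should_retry(502) and should_retry(429) and not should_retry(404)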
31 | 31 | TCP4ClientEndpoint, |
32 | 32 | TCP6ClientEndpoint, |
33 | 33 | ) |
34 | from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport | |
34 | from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint | |
35 | 35 | from twisted.internet.protocol import Factory, Protocol |
36 | from twisted.internet.tcp import Connection | |
36 | 37 | from twisted.python.failure import Failure |
37 | 38 | |
38 | 39 | logger = logging.getLogger(__name__) |
51 | 52 | format: A callable to format the log record to a string. |
52 | 53 | """ |
53 | 54 | |
54 | transport = attr.ib(type=ITransport) | |
55 | # This is essentially ITCPTransport, but that is missing certain fields | |
56 | # (connected and registerProducer) which are part of the implementation. | |
57 | transport = attr.ib(type=Connection) | |
55 | 58 | _format = attr.ib(type=Callable[[logging.LogRecord], str]) |
56 | 59 | _buffer = attr.ib(type=deque) |
57 | 60 | _paused = attr.ib(default=False, type=bool, init=False) |
148 | 151 | if self._connection_waiter: |
149 | 152 | return |
150 | 153 | |
151 | self._connection_waiter = self._service.whenConnected(failAfterFailures=1) | |
152 | ||
153 | 154 | def fail(failure: Failure) -> None: |
154 | 155 | # If the Deferred was cancelled (e.g. during shutdown) do not try to |
155 | 156 | # reconnect (this will cause an infinite loop of errors). |
162 | 163 | self._connect() |
163 | 164 | |
164 | 165 | def writer(result: Protocol) -> None: |
166 | # Force recognising transport as a Connection and not the more | |
167 | # generic ITransport. | |
168 | transport = result.transport # type: Connection # type: ignore | |
169 | ||
165 | 170 | # We have a connection. If we already have a producer, and its |
166 | 171 | # transport is the same, just trigger a resumeProducing. |
167 | if self._producer and result.transport is self._producer.transport: | |
172 | if self._producer and transport is self._producer.transport: | |
168 | 173 | self._producer.resumeProducing() |
169 | 174 | self._connection_waiter = None |
170 | 175 | return |
176 | 181 | # Make a new producer and start it. |
177 | 182 | self._producer = LogProducer( |
178 | 183 | buffer=self._buffer, |
179 | transport=result.transport, | |
184 | transport=transport, | |
180 | 185 | format=self.format, |
181 | 186 | ) |
182 | result.transport.registerProducer(self._producer, True) | |
187 | transport.registerProducer(self._producer, True) | |
183 | 188 | self._producer.resumeProducing() |
184 | 189 | self._connection_waiter = None |
185 | 190 | |
186 | self._connection_waiter.addCallbacks(writer, fail) | |
191 | deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred | |
192 | deferred.addCallbacks(writer, fail) | |
193 | self._connection_waiter = deferred | |
187 | 194 | |
188 | 195 | def _handle_pressure(self) -> None: |
189 | 196 | """ |
668 | 668 | return g |
669 | 669 | |
670 | 670 | |
671 | def run_in_background(f, *args, **kwargs): | |
671 | def run_in_background(f, *args, **kwargs) -> defer.Deferred: | |
672 | 672 | """Calls a function, ensuring that the current context is restored after |
673 | 673 | return from the function, and that the sentinel context is set once the |
674 | 674 | deferred returned by the function completes. |
696 | 696 | if isinstance(res, types.CoroutineType): |
697 | 697 | res = defer.ensureDeferred(res) |
698 | 698 | |
699 | # At this point we should have a Deferred, if not then f was a synchronous | |
700 | # function, wrap it in a Deferred for consistency. | |
699 | 701 | if not isinstance(res, defer.Deferred): |
700 | return res | |
702 | return defer.succeed(res) | |
701 | 703 | |
702 | 704 | if res.called and not res.paused: |
703 | 705 | # The function should have maintained the logcontext, so we can |
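The new fallback wraps a synchronous return value in an already-fired Deferred via `defer.succeed`, so every caller of `run_in_background` now gets a `Deferred` back. A small sketch of that pattern in isolation:

    from twisted.internet import defer

    def maybe_async(x):
        return x * 2  # a synchronous function: returns a plain int

    res = maybe_async(21)
    if not isinstance(res, defer.Deferred):
        # wrap for consistency, as the new comment above explains
        res = defer.succeed(res)

    res.addCallback(print)  # already fired: prints 42 immediately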
202 | 202 | ) |
203 | 203 | |
204 | 204 | def generate_short_term_login_token( |
205 | self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) | |
205 | self, | |
206 | user_id: str, | |
207 | duration_in_ms: int = (2 * 60 * 1000), | |
208 | auth_provider_id: str = "", | |
206 | 209 | ) -> str: |
207 | """Generate a login token suitable for m.login.token authentication""" | |
210 | """Generate a login token suitable for m.login.token authentication | |
211 | ||
212 | Args: | |
213 | user_id: gives the ID of the user that the token is for | |
214 | ||
215 | duration_in_ms: the time that the token will be valid for | |
216 | ||
217 | auth_provider_id: the ID of the SSO IdP that the user used to authenticate | |
218 | to get this token, if any. This is encoded in the token so that | |
219 | /login can report stats on number of successful logins by IdP. | |
220 | """ | |
208 | 221 | return self._hs.get_macaroon_generator().generate_short_term_login_token( |
209 | user_id, duration_in_ms | |
222 | user_id, | |
223 | auth_provider_id, | |
224 | duration_in_ms, | |
210 | 225 | ) |
211 | 226 | |
212 | 227 | @defer.inlineCallbacks |
275 | 290 | """ |
276 | 291 | self._auth_handler._complete_sso_login( |
277 | 292 | registered_user_id, |
293 | "<unknown>", | |
278 | 294 | request, |
279 | 295 | client_redirect_url, |
280 | 296 | ) |
285 | 301 | request: SynapseRequest, |
286 | 302 | client_redirect_url: str, |
287 | 303 | new_user: bool = False, |
304 | auth_provider_id: str = "<unknown>", | |
288 | 305 | ): |
289 | 306 | """Complete a SSO login by redirecting the user to a page to confirm whether they |
290 | 307 | want their access token sent to `client_redirect_url`, or redirect them to that |
298 | 315 | redirect them directly if whitelisted). |
299 | 316 | new_user: set to true to use wording for the consent appropriate to a user |
300 | 317 | who has just registered. |
318 | auth_provider_id: the ID of the SSO IdP which was used to log in. This | |
319 | is used to track counts of successful logins by IdP. | |
301 | 320 | """ |
302 | 321 | await self._auth_handler.complete_sso_login( |
303 | registered_user_id, request, client_redirect_url, new_user=new_user | |
322 | registered_user_id, | |
323 | auth_provider_id, | |
324 | request, | |
325 | client_redirect_url, | |
326 | new_user=new_user, | |
304 | 327 | ) |
305 | 328 | |
306 | 329 | @defer.inlineCallbacks |
15 | 15 | import logging |
16 | 16 | from typing import TYPE_CHECKING, Dict, List, Optional |
17 | 17 | |
18 | from twisted.internet.base import DelayedCall | |
19 | 18 | from twisted.internet.error import AlreadyCalled, AlreadyCancelled |
19 | from twisted.internet.interfaces import IDelayedCall | |
20 | 20 | |
21 | 21 | from synapse.metrics.background_process_metrics import run_as_background_process |
22 | 22 | from synapse.push import Pusher, PusherConfig, ThrottleParams |
65 | 65 | |
66 | 66 | self.store = self.hs.get_datastore() |
67 | 67 | self.email = pusher_config.pushkey |
68 | self.timed_call = None # type: Optional[DelayedCall] | |
68 | self.timed_call = None # type: Optional[IDelayedCall] | |
69 | 69 | self.throttle_params = {} # type: Dict[str, ThrottleParams] |
70 | 70 | self._inited = False |
71 | 71 |
17 | 17 | import re |
18 | 18 | import urllib |
19 | 19 | from inspect import signature |
20 | from typing import Dict, List, Tuple | |
20 | from typing import TYPE_CHECKING, Dict, List, Tuple | |
21 | 21 | |
22 | 22 | from prometheus_client import Counter, Gauge |
23 | 23 | |
26 | 26 | from synapse.logging.opentracing import inject_active_span_byte_dict, trace |
27 | 27 | from synapse.util.caches.response_cache import ResponseCache |
28 | 28 | from synapse.util.stringutils import random_string |
29 | ||
30 | if TYPE_CHECKING: | |
31 | from synapse.server import HomeServer | |
29 | 32 | |
30 | 33 | logger = logging.getLogger(__name__) |
31 | 34 | |
87 | 90 | CACHE = True |
88 | 91 | RETRY_ON_TIMEOUT = True |
89 | 92 | |
90 | def __init__(self, hs): | |
93 | def __init__(self, hs: "HomeServer"): | |
91 | 94 | if self.CACHE: |
92 | 95 | self.response_cache = ResponseCache( |
93 | hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 | |
96 | hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 | |
94 | 97 | ) # type: ResponseCache[str] |
95 | 98 | |
96 | 99 | # We reserve `instance_name` as a parameter to sending requests, so we |
60 | 60 | is_guest = content["is_guest"] |
61 | 61 | is_appservice_ghost = content["is_appservice_ghost"] |
62 | 62 | |
63 | device_id, access_token = await self.registration_handler.register_device( | |
63 | res = await self.registration_handler.register_device_inner( | |
64 | 64 | user_id, |
65 | 65 | device_id, |
66 | 66 | initial_display_name, |
68 | 68 | is_appservice_ghost=is_appservice_ghost, |
69 | 69 | ) |
70 | 70 | |
71 | return 200, {"device_id": device_id, "access_token": access_token} | |
71 | return 200, res | |
72 | 72 | |
73 | 73 | |
74 | 74 | def register_servlets(hs, http_server): |
47 | 47 | UserIpCommand, |
48 | 48 | UserSyncCommand, |
49 | 49 | ) |
50 | from synapse.replication.tcp.protocol import AbstractConnection | |
50 | from synapse.replication.tcp.protocol import IReplicationConnection | |
51 | 51 | from synapse.replication.tcp.streams import ( |
52 | 52 | STREAMS_MAP, |
53 | 53 | AccountDataStream, |
81 | 81 | |
82 | 82 | # the type of the entries in _command_queues_by_stream |
83 | 83 | _StreamCommandQueue = Deque[ |
84 | Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] | |
84 | Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection] | |
85 | 85 | ] |
86 | 86 | |
87 | 87 | |
173 | 173 | |
174 | 174 | # The currently connected connections. (The list of places we need to send |
175 | 175 | # outgoing replication commands to.) |
176 | self._connections = [] # type: List[AbstractConnection] | |
176 | self._connections = [] # type: List[IReplicationConnection] | |
177 | 177 | |
178 | 178 | LaterGauge( |
179 | 179 | "synapse_replication_tcp_resource_total_connections", |
196 | 196 | |
197 | 197 | # For each connection, the incoming stream names that have received a POSITION |
198 | 198 | # from that connection. |
199 | self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] | |
199 | self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] | |
200 | 200 | |
201 | 201 | LaterGauge( |
202 | 202 | "synapse_replication_tcp_command_queue", |
219 | 219 | self._server_notices_sender = hs.get_server_notices_sender() |
220 | 220 | |
221 | 221 | def _add_command_to_stream_queue( |
222 | self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] | |
222 | self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] | |
223 | 223 | ) -> None: |
224 | 224 | """Queue the given received command for processing |
225 | 225 | |
266 | 266 | async def _process_command( |
267 | 267 | self, |
268 | 268 | cmd: Union[PositionCommand, RdataCommand], |
269 | conn: AbstractConnection, | |
269 | conn: IReplicationConnection, | |
270 | 270 | stream_name: str, |
271 | 271 | ) -> None: |
272 | 272 | if isinstance(cmd, PositionCommand): |
301 | 301 | hs, outbound_redis_connection |
302 | 302 | ) |
303 | 303 | hs.get_reactor().connectTCP( |
304 | hs.config.redis.redis_host, | |
304 | hs.config.redis.redis_host.encode(), | |
305 | 305 | hs.config.redis.redis_port, |
306 | 306 | self._factory, |
307 | 307 | ) |
310 | 310 | self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) |
311 | 311 | host = hs.config.worker_replication_host |
312 | 312 | port = hs.config.worker_replication_port |
313 | hs.get_reactor().connectTCP(host, port, self._factory) | |
313 | hs.get_reactor().connectTCP(host.encode(), port, self._factory) | |
314 | 314 | |
315 | 315 | def get_streams(self) -> Dict[str, Stream]: |
316 | 316 | """Get a map from stream name to all streams.""" |
320 | 320 | """Get a list of streams that this instances replicates.""" |
321 | 321 | return self._streams_to_replicate |
322 | 322 | |
323 | def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): | |
323 | def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand): | |
324 | 324 | self.send_positions_to_connection(conn) |
325 | 325 | |
326 | def send_positions_to_connection(self, conn: AbstractConnection): | |
326 | def send_positions_to_connection(self, conn: IReplicationConnection): | |
327 | 327 | """Send current position of all streams this process is source of to |
328 | 328 | the connection. |
329 | 329 | """ |
346 | 346 | ) |
347 | 347 | |
348 | 348 | def on_USER_SYNC( |
349 | self, conn: AbstractConnection, cmd: UserSyncCommand | |
349 | self, conn: IReplicationConnection, cmd: UserSyncCommand | |
350 | 350 | ) -> Optional[Awaitable[None]]: |
351 | 351 | user_sync_counter.inc() |
352 | 352 | |
358 | 358 | return None |
359 | 359 | |
360 | 360 | def on_CLEAR_USER_SYNC( |
361 | self, conn: AbstractConnection, cmd: ClearUserSyncsCommand | |
361 | self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand | |
362 | 362 | ) -> Optional[Awaitable[None]]: |
363 | 363 | if self._is_master: |
364 | 364 | return self._presence_handler.update_external_syncs_clear(cmd.instance_id) |
365 | 365 | else: |
366 | 366 | return None |
367 | 367 | |
368 | def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): | |
368 | def on_FEDERATION_ACK( | |
369 | self, conn: IReplicationConnection, cmd: FederationAckCommand | |
370 | ): | |
369 | 371 | federation_ack_counter.inc() |
370 | 372 | |
371 | 373 | if self._federation_sender: |
372 | 374 | self._federation_sender.federation_ack(cmd.instance_name, cmd.token) |
373 | 375 | |
374 | 376 | def on_USER_IP( |
375 | self, conn: AbstractConnection, cmd: UserIpCommand | |
377 | self, conn: IReplicationConnection, cmd: UserIpCommand | |
376 | 378 | ) -> Optional[Awaitable[None]]: |
377 | 379 | user_ip_cache_counter.inc() |
378 | 380 | |
394 | 396 | assert self._server_notices_sender is not None |
395 | 397 | await self._server_notices_sender.on_user_ip(cmd.user_id) |
396 | 398 | |
397 | def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): | |
399 | def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand): | |
398 | 400 | if cmd.instance_name == self._instance_name: |
399 | 401 | # Ignore RDATA that are just our own echoes |
400 | 402 | return |
411 | 413 | self._add_command_to_stream_queue(conn, cmd) |
412 | 414 | |
413 | 415 | async def _process_rdata( |
414 | self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand | |
416 | self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand | |
415 | 417 | ) -> None: |
416 | 418 | """Process an RDATA command |
417 | 419 | |
485 | 487 | stream_name, instance_name, token, rows |
486 | 488 | ) |
487 | 489 | |
488 | def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): | |
490 | def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand): | |
489 | 491 | if cmd.instance_name == self._instance_name: |
490 | 492 | # Ignore POSITION that are just our own echoes |
491 | 493 | return |
495 | 497 | self._add_command_to_stream_queue(conn, cmd) |
496 | 498 | |
497 | 499 | async def _process_position( |
498 | self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand | |
500 | self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand | |
499 | 501 | ) -> None: |
500 | 502 | """Process a POSITION command |
501 | 503 | |
552 | 554 | |
553 | 555 | self._streams_by_connection.setdefault(conn, set()).add(stream_name) |
554 | 556 | |
555 | def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): | |
557 | def on_REMOTE_SERVER_UP( | |
558 | self, conn: IReplicationConnection, cmd: RemoteServerUpCommand | |
559 | ): | |
556 | 560 | """Called when we get a new REMOTE_SERVER_UP command."""
557 | 561 | self._replication_data_handler.on_remote_server_up(cmd.data) |
558 | 562 | |
575 | 579 | # between two instances, but that is not currently supported). |
576 | 580 | self.send_command(cmd, ignore_conn=conn) |
577 | 581 | |
578 | def new_connection(self, connection: AbstractConnection): | |
582 | def new_connection(self, connection: IReplicationConnection): | |
579 | 583 | """Called when we have a new connection.""" |
580 | 584 | self._connections.append(connection) |
581 | 585 | |
602 | 606 | UserSyncCommand(self._instance_id, user_id, True, now) |
603 | 607 | ) |
604 | 608 | |
605 | def lost_connection(self, connection: AbstractConnection): | |
609 | def lost_connection(self, connection: IReplicationConnection): | |
606 | 610 | """Called when a connection is closed/lost.""" |
607 | 611 | # we no longer need _streams_by_connection for this connection. |
608 | 612 | streams = self._streams_by_connection.pop(connection, None) |
623 | 627 | return bool(self._connections) |
624 | 628 | |
625 | 629 | def send_command( |
626 | self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None | |
630 | self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None | |
627 | 631 | ): |
628 | 632 | """Send a command to all connected connections. |
629 | 633 |
45 | 45 | > ERROR server stopping |
46 | 46 | * connection closed by server * |
47 | 47 | """ |
48 | import abc | |
49 | 48 | import fcntl |
50 | 49 | import logging |
51 | 50 | import struct |
53 | 52 | from typing import TYPE_CHECKING, List, Optional |
54 | 53 | |
55 | 54 | from prometheus_client import Counter |
55 | from zope.interface import Interface, implementer | |
56 | 56 | |
57 | 57 | from twisted.internet import task |
58 | from twisted.internet.tcp import Connection | |
58 | 59 | from twisted.protocols.basic import LineOnlyReceiver |
59 | 60 | from twisted.python.failure import Failure |
60 | 61 | |
120 | 121 | CLOSED = "closed" |
121 | 122 | |
122 | 123 | |
124 | class IReplicationConnection(Interface): | |
125 | """An interface for replication connections.""" | |
126 | ||
127 | def send_command(cmd: Command): | |
128 | """Send the command down the connection""" | |
129 | ||
130 | ||
131 | @implementer(IReplicationConnection) | |
123 | 132 | class BaseReplicationStreamProtocol(LineOnlyReceiver): |
124 | 133 | """Base replication protocol shared between client and server. |
125 | 134 | |
136 | 145 | (if they send a `PING` command) |
137 | 146 | """ |
138 | 147 | |
148 | # The transport is going to be an ITCPTransport, but that doesn't have the | |
149 | # (un)registerProducer methods, those are only on the implementation. | |
150 | transport = None # type: Connection | |
151 | ||
139 | 152 | delimiter = b"\n" |
140 | 153 | |
141 | 154 | # Valid commands we expect to receive |
180 | 193 | |
181 | 194 | connected_connections.append(self) # Register connection for metrics |
182 | 195 | |
196 | assert self.transport is not None | |
183 | 197 | self.transport.registerProducer(self, True) # For the *Producing callbacks |
184 | 198 | |
185 | 199 | self._send_pending_commands() |
204 | 218 | logger.info( |
205 | 219 | "[%s] Failed to close connection gracefully, aborting", self.id() |
206 | 220 | ) |
221 | assert self.transport is not None | |
207 | 222 | self.transport.abortConnection() |
208 | 223 | else: |
209 | 224 | if now - self.last_sent_command >= PING_TIME: |
293 | 308 | def close(self): |
294 | 309 | logger.warning("[%s] Closing connection", self.id()) |
295 | 310 | self.time_we_closed = self.clock.time_msec() |
311 | assert self.transport is not None | |
296 | 312 | self.transport.loseConnection() |
297 | 313 | self.on_connection_closed() |
298 | 314 | |
390 | 406 | def connectionLost(self, reason): |
391 | 407 | logger.info("[%s] Replication connection closed: %r", self.id(), reason) |
392 | 408 | if isinstance(reason, Failure): |
409 | assert reason.type is not None | |
393 | 410 | connection_close_counter.labels(reason.type.__name__).inc() |
394 | 411 | else: |
395 | 412 | connection_close_counter.labels(reason.__class__.__name__).inc() |
494 | 511 | self.send_command(ReplicateCommand()) |
495 | 512 | |
496 | 513 | |
497 | class AbstractConnection(abc.ABC): | |
498 | """An interface for replication connections.""" | |
499 | ||
500 | @abc.abstractmethod | |
501 | def send_command(self, cmd: Command): | |
502 | """Send the command down the connection""" | |
503 | pass | |
504 | ||
505 | ||
506 | # This tells python that `BaseReplicationStreamProtocol` implements the | |
507 | # interface. | |
508 | AbstractConnection.register(BaseReplicationStreamProtocol) | |
509 | ||
510 | ||
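The abc-based `AbstractConnection` (and its explicit `register` call) is replaced by a `zope.interface` interface plus an `@implementer` declaration, as in the protocol hunk above. The equivalent pattern in isolation, with an illustrative `IGreeter` interface:

    from zope.interface import Interface, implementer
    from zope.interface.verify import verifyClass

    class IGreeter(Interface):
        """Illustrative interface; zope interface methods omit `self`."""

        def greet(name: str) -> str:
            """Return a greeting for `name`."""

    @implementer(IGreeter)
    class ConsoleGreeter:
        def greet(self, name: str) -> str:
            return "hello " + name

    verifyClass(IGreeter, ConsoleGreeter)  # raises if the contract is unmet
    print(ConsoleGreeter().greet("synapse"))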
511 | 514 | # The following simply registers metrics for the replication connections |
512 | 515 | |
513 | 516 | pending_commands = LaterGauge( |
18 | 18 | |
19 | 19 | import attr |
20 | 20 | import txredisapi |
21 | from zope.interface import implementer | |
22 | ||
23 | from twisted.internet.address import IPv4Address, IPv6Address | |
24 | from twisted.internet.interfaces import IAddress, IConnector | |
25 | from twisted.python.failure import Failure | |
21 | 26 | |
22 | 27 | from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable |
23 | 28 | from synapse.metrics.background_process_metrics import ( |
31 | 36 | parse_command_from_line, |
32 | 37 | ) |
33 | 38 | from synapse.replication.tcp.protocol import ( |
34 | AbstractConnection, | |
39 | IReplicationConnection, | |
35 | 40 | tcp_inbound_commands_counter, |
36 | 41 | tcp_outbound_commands_counter, |
37 | 42 | ) |
61 | 66 | pass |
62 | 67 | |
63 | 68 | |
64 | class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): | |
69 | @implementer(IReplicationConnection) | |
70 | class RedisSubscriber(txredisapi.SubscriberProtocol): | |
65 | 71 | """Connection to redis subscribed to replication stream. |
66 | 72 | |
67 | 73 | This class fulfils two functions: |
70 | 76 | connection, parsing *incoming* messages into replication commands, and passing them |
71 | 77 | to `ReplicationCommandHandler` |
72 | 78 | |
73 | (b) it implements the AbstractConnection API, where it sends *outgoing* commands | |
79 | (b) it implements the IReplicationConnection API, where it sends *outgoing* commands | |
74 | 80 | onto outbound_redis_connection. |
75 | 81 | |
76 | 82 | Due to the vagaries of `txredisapi` we don't want to have a custom |
252 | 258 | except Exception: |
253 | 259 | logger.warning("Failed to send ping to a redis connection") |
254 | 260 | |
261 | # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but | |
262 | # it's rubbish. We add our own here. | |
263 | ||
264 | def startedConnecting(self, connector: IConnector): | |
265 | logger.info( | |
266 | "Connecting to redis server %s", format_address(connector.getDestination()) | |
267 | ) | |
268 | super().startedConnecting(connector) | |
269 | ||
270 | def clientConnectionFailed(self, connector: IConnector, reason: Failure): | |
271 | logger.info( | |
272 | "Connection to redis server %s failed: %s", | |
273 | format_address(connector.getDestination()), | |
274 | reason.value, | |
275 | ) | |
276 | super().clientConnectionFailed(connector, reason) | |
277 | ||
278 | def clientConnectionLost(self, connector: IConnector, reason: Failure): | |
279 | logger.info( | |
280 | "Connection to redis server %s lost: %s", | |
281 | format_address(connector.getDestination()), | |
282 | reason.value, | |
283 | ) | |
284 | super().clientConnectionLost(connector, reason) | |
285 | ||
286 | ||
287 | def format_address(address: IAddress) -> str: | |
288 | if isinstance(address, (IPv4Address, IPv6Address)): | |
289 | return "%s:%i" % (address.host, address.port) | |
290 | return str(address) | |
291 | ||
255 | 292 | |
256 | 293 | class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): |
257 | 294 | """This is a reconnecting factory that connects to redis and immediately |
327 | 364 | factory.continueTrying = reconnect |
328 | 365 | |
329 | 366 | reactor = hs.get_reactor() |
330 | reactor.connectTCP(host, port, factory, 30) | |
367 | reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) | |
331 | 368 | |
332 | 369 | return factory.handler |
14 | 14 | |
15 | 15 | import re |
16 | 16 | |
17 | import twisted.web.server | |
18 | ||
19 | import synapse.api.auth | |
17 | from synapse.api.auth import Auth | |
20 | 18 | from synapse.api.errors import AuthError |
19 | from synapse.http.site import SynapseRequest | |
21 | 20 | from synapse.types import UserID |
22 | 21 | |
23 | 22 | |
36 | 35 | return patterns |
37 | 36 | |
38 | 37 | |
39 | async def assert_requester_is_admin( | |
40 | auth: synapse.api.auth.Auth, request: twisted.web.server.Request | |
41 | ) -> None: | |
38 | async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None: | |
42 | 39 | """Verify that the requester is an admin user |
43 | 40 | |
44 | 41 | Args: |
45 | auth: api.auth.Auth singleton | |
42 | auth: Auth singleton | |
46 | 43 | request: incoming request |
47 | 44 | |
48 | 45 | Raises: |
52 | 49 | await assert_user_is_admin(auth, requester.user) |
53 | 50 | |
54 | 51 | |
55 | async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None: | |
52 | async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: | |
56 | 53 | """Verify that the given user is an admin user |
57 | 54 | |
58 | 55 | Args: |
59 | auth: api.auth.Auth singleton | |
56 | auth: Auth singleton | |
60 | 57 | user_id: user to check |
61 | 58 | |
62 | 59 | Raises: |
16 | 16 | import logging |
17 | 17 | from typing import TYPE_CHECKING, Tuple |
18 | 18 | |
19 | from twisted.web.server import Request | |
20 | ||
21 | 19 | from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError |
22 | 20 | from synapse.http.servlet import RestServlet, parse_boolean, parse_integer |
21 | from synapse.http.site import SynapseRequest | |
23 | 22 | from synapse.rest.admin._base import ( |
24 | 23 | admin_patterns, |
25 | 24 | assert_requester_is_admin, |
49 | 48 | self.store = hs.get_datastore() |
50 | 49 | self.auth = hs.get_auth() |
51 | 50 | |
52 | async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: | |
51 | async def on_POST( | |
52 | self, request: SynapseRequest, room_id: str | |
53 | ) -> Tuple[int, JsonDict]: | |
53 | 54 | requester = await self.auth.get_user_by_req(request) |
54 | 55 | await assert_user_is_admin(self.auth, requester.user) |
55 | 56 | |
74 | 75 | self.store = hs.get_datastore() |
75 | 76 | self.auth = hs.get_auth() |
76 | 77 | |
77 | async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: | |
78 | async def on_POST( | |
79 | self, request: SynapseRequest, user_id: str | |
80 | ) -> Tuple[int, JsonDict]: | |
78 | 81 | requester = await self.auth.get_user_by_req(request) |
79 | 82 | await assert_user_is_admin(self.auth, requester.user) |
80 | 83 | |
102 | 105 | self.auth = hs.get_auth() |
103 | 106 | |
104 | 107 | async def on_POST( |
105 | self, request: Request, server_name: str, media_id: str | |
108 | self, request: SynapseRequest, server_name: str, media_id: str | |
106 | 109 | ) -> Tuple[int, JsonDict]: |
107 | 110 | requester = await self.auth.get_user_by_req(request) |
108 | 111 | await assert_user_is_admin(self.auth, requester.user) |
126 | 129 | self.store = hs.get_datastore() |
127 | 130 | self.auth = hs.get_auth() |
128 | 131 | |
129 | async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: | |
132 | async def on_POST( | |
133 | self, request: SynapseRequest, media_id: str | |
134 | ) -> Tuple[int, JsonDict]: | |
130 | 135 | requester = await self.auth.get_user_by_req(request) |
131 | 136 | await assert_user_is_admin(self.auth, requester.user) |
132 | 137 | |
147 | 152 | self.store = hs.get_datastore() |
148 | 153 | self.auth = hs.get_auth() |
149 | 154 | |
150 | async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: | |
155 | async def on_GET( | |
156 | self, request: SynapseRequest, room_id: str | |
157 | ) -> Tuple[int, JsonDict]: | |
151 | 158 | requester = await self.auth.get_user_by_req(request) |
152 | 159 | is_admin = await self.auth.is_server_admin(requester.user) |
153 | 160 | if not is_admin: |
165 | 172 | self.media_repository = hs.get_media_repository() |
166 | 173 | self.auth = hs.get_auth() |
167 | 174 | |
168 | async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: | |
175 | async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |
169 | 176 | await assert_requester_is_admin(self.auth, request) |
170 | 177 | |
171 | 178 | before_ts = parse_integer(request, "before_ts", required=True) |
188 | 195 | self.media_repository = hs.get_media_repository() |
189 | 196 | |
190 | 197 | async def on_DELETE( |
191 | self, request: Request, server_name: str, media_id: str | |
198 | self, request: SynapseRequest, server_name: str, media_id: str | |
192 | 199 | ) -> Tuple[int, JsonDict]: |
193 | 200 | await assert_requester_is_admin(self.auth, request) |
194 | 201 | |
217 | 224 | self.server_name = hs.hostname |
218 | 225 | self.media_repository = hs.get_media_repository() |
219 | 226 | |
220 | async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: | |
227 | async def on_POST( | |
228 | self, request: SynapseRequest, server_name: str | |
229 | ) -> Tuple[int, JsonDict]: | |
221 | 230 | await assert_requester_is_admin(self.auth, request) |
222 | 231 | |
223 | 232 | before_ts = parse_integer(request, "before_ts", required=True) |
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | from typing import TYPE_CHECKING, Tuple | |
15 | ||
14 | 16 | from synapse.http.servlet import ( |
15 | 17 | RestServlet, |
16 | 18 | assert_params_in_dict, |
17 | 19 | parse_json_object_from_request, |
18 | 20 | ) |
21 | from synapse.http.site import SynapseRequest | |
19 | 22 | from synapse.rest.admin import assert_requester_is_admin |
20 | 23 | from synapse.rest.admin._base import admin_patterns |
24 | from synapse.types import JsonDict | |
25 | ||
26 | if TYPE_CHECKING: | |
27 | from synapse.server import HomeServer | |
21 | 28 | |
22 | 29 | |
23 | 30 | class PurgeRoomServlet(RestServlet): |
35 | 42 | |
36 | 43 | PATTERNS = admin_patterns("/purge_room$") |
37 | 44 | |
38 | def __init__(self, hs): | |
39 | """ | |
40 | Args: | |
41 | hs (synapse.server.HomeServer): server | |
42 | """ | |
45 | def __init__(self, hs: "HomeServer"): | |
43 | 46 | self.hs = hs |
44 | 47 | self.auth = hs.get_auth() |
45 | 48 | self.pagination_handler = hs.get_pagination_handler() |
46 | 49 | |
47 | async def on_POST(self, request): | |
50 | async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |
48 | 51 | await assert_requester_is_admin(self.auth, request) |
49 | 52 | |
50 | 53 | body = parse_json_object_from_request(request) |
684 | 684 | results["events_after"], time_now |
685 | 685 | ) |
686 | 686 | results["state"] = await self._event_serializer.serialize_events( |
687 | results["state"], time_now | |
687 | results["state"], | |
688 | time_now, | |
689 | # No need to bundle aggregations for state events | |
690 | bundle_aggregations=False, | |
688 | 691 | ) |
689 | 692 | |
690 | 693 | return 200, results |
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | from typing import TYPE_CHECKING, Optional, Tuple | |
15 | ||
14 | 16 | from synapse.api.constants import EventTypes |
15 | 17 | from synapse.api.errors import SynapseError |
18 | from synapse.http.server import HttpServer | |
16 | 19 | from synapse.http.servlet import ( |
17 | 20 | RestServlet, |
18 | 21 | assert_params_in_dict, |
19 | 22 | parse_json_object_from_request, |
20 | 23 | ) |
24 | from synapse.http.site import SynapseRequest | |
21 | 25 | from synapse.rest.admin import assert_requester_is_admin |
22 | 26 | from synapse.rest.admin._base import admin_patterns |
23 | 27 | from synapse.rest.client.transactions import HttpTransactionCache |
24 | from synapse.types import UserID | |
28 | from synapse.types import JsonDict, UserID | |
29 | ||
30 | if TYPE_CHECKING: | |
31 | from synapse.server import HomeServer | |
25 | 32 | |
26 | 33 | |
27 | 34 | class SendServerNoticeServlet(RestServlet): |
43 | 50 | } |
44 | 51 | """ |
45 | 52 | |
46 | def __init__(self, hs): | |
47 | """ | |
48 | Args: | |
49 | hs (synapse.server.HomeServer): server | |
50 | """ | |
53 | def __init__(self, hs: "HomeServer"): | |
51 | 54 | self.hs = hs |
52 | 55 | self.auth = hs.get_auth() |
53 | 56 | self.txns = HttpTransactionCache(hs) |
54 | 57 | self.snm = hs.get_server_notices_manager() |
55 | 58 | |
56 | def register(self, json_resource): | |
59 | def register(self, json_resource: HttpServer): | |
57 | 60 | PATTERN = "/send_server_notice" |
58 | 61 | json_resource.register_paths( |
59 | 62 | "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ |
65 | 68 | self.__class__.__name__, |
66 | 69 | ) |
67 | 70 | |
68 | async def on_POST(self, request, txn_id=None): | |
71 | async def on_POST( | |
72 | self, request: SynapseRequest, txn_id: Optional[str] = None | |
73 | ) -> Tuple[int, JsonDict]: | |
69 | 74 | await assert_requester_is_admin(self.auth, request) |
70 | 75 | body = parse_json_object_from_request(request) |
71 | 76 | assert_params_in_dict(body, ("user_id", "content")) |
89 | 94 | |
90 | 95 | return 200, {"event_id": event.event_id} |
91 | 96 | |
92 | def on_PUT(self, request, txn_id): | |
97 | def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]: | |
93 | 98 | return self.txns.fetch_or_execute_request( |
94 | 99 | request, self.on_POST, request, txn_id |
95 | 100 | ) |
268 | 268 | target_user.to_string(), False, requester, by_admin=True |
269 | 269 | ) |
270 | 270 | elif not deactivate and user["deactivated"]: |
271 | if "password" not in body: | |
271 | if ( | |
272 | "password" not in body | |
273 | and self.hs.config.password_localdb_enabled | |
274 | ): | |
272 | 275 | raise SynapseError( |
273 | 276 | 400, "Must provide a password to re-activate an account." |
274 | 277 | ) |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | import logging |
16 | import re | |
16 | 17 | from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional |
17 | 18 | |
18 | 19 | from synapse.api.errors import Codes, LoginError, SynapseError |
19 | 20 | from synapse.api.ratelimiting import Ratelimiter |
21 | from synapse.api.urls import CLIENT_API_PREFIX | |
20 | 22 | from synapse.appservice import ApplicationService |
21 | 23 | from synapse.handlers.sso import SsoIdentityProvider |
22 | 24 | from synapse.http import get_request_uri |
93 | 95 | flows.append({"type": LoginRestServlet.CAS_TYPE}) |
94 | 96 | |
95 | 97 | if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: |
96 | sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict | |
98 | sso_flow = { | |
99 | "type": LoginRestServlet.SSO_TYPE, | |
100 | "identity_providers": [ | |
101 | _get_auth_flow_dict_for_idp( | |
102 | idp, | |
103 | ) | |
104 | for idp in self._sso_handler.get_identity_providers().values() | |
105 | ], | |
106 | } # type: JsonDict | |
97 | 107 | |
98 | 108 | if self._msc2858_enabled: |
109 | # backwards-compatibility support for clients which don't | |
110 | # support the stable API yet | |
99 | 111 | sso_flow["org.matrix.msc2858.identity_providers"] = [ |
100 | _get_auth_flow_dict_for_idp(idp) | |
112 | _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) | |
101 | 113 | for idp in self._sso_handler.get_identity_providers().values() |
102 | 114 | ] |
103 | 115 | |
218 | 230 | callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, |
219 | 231 | create_non_existent_users: bool = False, |
220 | 232 | ratelimit: bool = True, |
233 | auth_provider_id: Optional[str] = None, | |
221 | 234 | ) -> Dict[str, str]: |
222 | 235 | """Called when we've successfully authed the user and now need to |
223 | 236 | actually login them in (e.g. create devices). This gets called on |
233 | 246 | create_non_existent_users: Whether to create the user if they don't |
234 | 247 | exist. Defaults to False. |
235 | 248 | ratelimit: Whether to ratelimit the login request. |
249 | auth_provider_id: The SSO IdP the user used, if any (just used for the | |
250 | prometheus metrics). | |
236 | 251 | |
237 | 252 | Returns: |
238 | 253 | result: Dictionary of account information after successful login. |
255 | 270 | device_id = login_submission.get("device_id") |
256 | 271 | initial_display_name = login_submission.get("initial_device_display_name") |
257 | 272 | device_id, access_token = await self.registration_handler.register_device( |
258 | user_id, device_id, initial_display_name | |
273 | user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id | |
259 | 274 | ) |
260 | 275 | |
261 | 276 | result = { |
282 | 297 | """ |
283 | 298 | token = login_submission["token"] |
284 | 299 | auth_handler = self.auth_handler |
285 | user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( | |
286 | token | |
287 | ) | |
300 | res = await auth_handler.validate_short_term_login_token(token) | |
288 | 301 | |
289 | 302 | return await self._complete_login( |
290 | user_id, login_submission, self.auth_handler._sso_login_callback | |
303 | res.user_id, | |
304 | login_submission, | |
305 | self.auth_handler._sso_login_callback, | |
306 | auth_provider_id=res.auth_provider_id, | |
291 | 307 | ) |
292 | 308 | |
293 | 309 | async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: |
326 | 342 | return result |
327 | 343 | |
328 | 344 | |
329 | def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: | |
345 | def _get_auth_flow_dict_for_idp( | |
346 | idp: SsoIdentityProvider, use_unstable_brands: bool = False | |
347 | ) -> JsonDict: | |
330 | 348 | """Return an entry for the login flow dict |
331 | 349 | |
332 | 350 | Returns an entry suitable for inclusion in "identity_providers" in the |
333 | 351 | response to GET /_matrix/client/r0/login |
352 | ||
353 | Args: | |
354 | idp: the identity provider to describe | |
355 | use_unstable_brands: whether we should use brand identifiers suitable | |
356 | for the unstable API | |
334 | 357 | """ |
335 | 358 | e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict |
336 | 359 | if idp.idp_icon: |
337 | 360 | e["icon"] = idp.idp_icon |
338 | 361 | if idp.idp_brand: |
339 | 362 | e["brand"] = idp.idp_brand |
363 | # for the unstable API, prefer the unstable brand identifier where one is defined. | |
364 | if use_unstable_brands and idp.unstable_idp_brand: | |
365 | e["brand"] = idp.unstable_idp_brand | |
340 | 366 | return e |
341 | 367 | |
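With this branching, the stable brand is emitted by default and only overridden for the unstable (MSC2858) API when an unstable identifier exists. A self-contained sketch with a stub IdP object (all values illustrative):

    from types import SimpleNamespace

    def get_auth_flow_dict_for_idp(idp, use_unstable_brands=False):
        e = {"id": idp.idp_id, "name": idp.idp_name}
        if idp.idp_icon:
            e["icon"] = idp.idp_icon
        if idp.idp_brand:
            e["brand"] = idp.idp_brand
        if use_unstable_brands and idp.unstable_idp_brand:
            e["brand"] = idp.unstable_idp_brand  # unstable API override
        return e

    idp = SimpleNamespace(
        idp_id="oidc", idp_name="OIDC", idp_icon=None,
        idp_brand="org.example", unstable_idp_brand="org.matrix.unstable.example",
    )
    print(get_auth_flow_dict_for_idp(idp))                            # stable brand
    print(get_auth_flow_dict_for_idp(idp, use_unstable_brands=True))  # unstable brand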
342 | 368 | |
343 | 369 | class SsoRedirectServlet(RestServlet): |
344 | PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) | |
370 | PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ | |
371 | re.compile( | |
372 | "^" | |
373 | + CLIENT_API_PREFIX | |
374 | + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" | |
375 | ) | |
376 | ] | |
345 | 377 | |
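The added pattern routes `/login/sso/redirect/{idp_id}` to the servlet with the IdP id captured by name. A quick check of the capture group (the prefix value below is assumed to match Synapse's `CLIENT_API_PREFIX`):

    import re

    CLIENT_API_PREFIX = "/_matrix/client"  # assumed value, for illustration
    pattern = re.compile(
        "^" + CLIENT_API_PREFIX + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
    )

    m = pattern.match("/_matrix/client/r0/login/sso/redirect/oidc-github")
    assert m is not None
    print(m.group("idp_id"))  # -> oidc-github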
346 | 378 | def __init__(self, hs: "HomeServer"): |
347 | 379 | # make sure that the relevant handlers are instantiated, so that they |
359 | 391 | def register(self, http_server: HttpServer) -> None: |
360 | 392 | super().register(http_server) |
361 | 393 | if self._msc2858_enabled: |
362 | # expose additional endpoint for MSC2858 support | |
394 | # expose additional endpoint for MSC2858 support: backwards-compat support | |
395 | # for clients which don't yet support the stable endpoints. | |
363 | 396 | http_server.register_paths( |
364 | 397 | "GET", |
365 | 398 | client_patterns( |
670 | 670 | results["events_after"], time_now |
671 | 671 | ) |
672 | 672 | results["state"] = await self._event_serializer.serialize_events( |
673 | results["state"], time_now | |
673 | results["state"], | |
674 | time_now, | |
675 | # No need to bundle aggregations for state events | |
676 | bundle_aggregations=False, | |
674 | 677 | ) |
675 | 678 | |
676 | 679 | return 200, results |
31 | 31 | assert_params_in_dict, |
32 | 32 | parse_json_object_from_request, |
33 | 33 | ) |
34 | from synapse.http.site import SynapseRequest | |
34 | 35 | from synapse.types import GroupID, JsonDict |
35 | 36 | |
36 | 37 | from ._base import client_patterns |
69 | 70 | self.groups_handler = hs.get_groups_local_handler() |
70 | 71 | |
71 | 72 | @_validate_group_id |
72 | async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
73 | async def on_GET( | |
74 | self, request: SynapseRequest, group_id: str | |
75 | ) -> Tuple[int, JsonDict]: | |
73 | 76 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
74 | 77 | requester_user_id = requester.user.to_string() |
75 | 78 | |
80 | 83 | return 200, group_description |
81 | 84 | |
82 | 85 | @_validate_group_id |
83 | async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
86 | async def on_POST( | |
87 | self, request: SynapseRequest, group_id: str | |
88 | ) -> Tuple[int, JsonDict]: | |
84 | 89 | requester = await self.auth.get_user_by_req(request) |
85 | 90 | requester_user_id = requester.user.to_string() |
86 | 91 | |
110 | 115 | self.groups_handler = hs.get_groups_local_handler() |
111 | 116 | |
112 | 117 | @_validate_group_id |
113 | async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
118 | async def on_GET( | |
119 | self, request: SynapseRequest, group_id: str | |
120 | ) -> Tuple[int, JsonDict]: | |
114 | 121 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
115 | 122 | requester_user_id = requester.user.to_string() |
116 | 123 | |
143 | 150 | |
144 | 151 | @_validate_group_id |
145 | 152 | async def on_PUT( |
146 | self, request: Request, group_id: str, category_id: Optional[str], room_id: str | |
153 | self, | |
154 | request: SynapseRequest, | |
155 | group_id: str, | |
156 | category_id: Optional[str], | |
157 | room_id: str, | |
147 | 158 | ): |
148 | 159 | requester = await self.auth.get_user_by_req(request) |
149 | 160 | requester_user_id = requester.user.to_string() |
175 | 186 | |
176 | 187 | @_validate_group_id |
177 | 188 | async def on_DELETE( |
178 | self, request: Request, group_id: str, category_id: str, room_id: str | |
189 | self, request: SynapseRequest, group_id: str, category_id: str, room_id: str | |
179 | 190 | ): |
180 | 191 | requester = await self.auth.get_user_by_req(request) |
181 | 192 | requester_user_id = requester.user.to_string() |
205 | 216 | |
206 | 217 | @_validate_group_id |
207 | 218 | async def on_GET( |
208 | self, request: Request, group_id: str, category_id: str | |
219 | self, request: SynapseRequest, group_id: str, category_id: str | |
209 | 220 | ) -> Tuple[int, JsonDict]: |
210 | 221 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
211 | 222 | requester_user_id = requester.user.to_string() |
218 | 229 | |
219 | 230 | @_validate_group_id |
220 | 231 | async def on_PUT( |
221 | self, request: Request, group_id: str, category_id: str | |
232 | self, request: SynapseRequest, group_id: str, category_id: str | |
222 | 233 | ) -> Tuple[int, JsonDict]: |
223 | 234 | requester = await self.auth.get_user_by_req(request) |
224 | 235 | requester_user_id = requester.user.to_string() |
246 | 257 | |
247 | 258 | @_validate_group_id |
248 | 259 | async def on_DELETE( |
249 | self, request: Request, group_id: str, category_id: str | |
260 | self, request: SynapseRequest, group_id: str, category_id: str | |
250 | 261 | ) -> Tuple[int, JsonDict]: |
251 | 262 | requester = await self.auth.get_user_by_req(request) |
252 | 263 | requester_user_id = requester.user.to_string() |
273 | 284 | self.groups_handler = hs.get_groups_local_handler() |
274 | 285 | |
275 | 286 | @_validate_group_id |
276 | async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
287 | async def on_GET( | |
288 | self, request: SynapseRequest, group_id: str | |
289 | ) -> Tuple[int, JsonDict]: | |
277 | 290 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
278 | 291 | requester_user_id = requester.user.to_string() |
279 | 292 | |
297 | 310 | |
298 | 311 | @_validate_group_id |
299 | 312 | async def on_GET( |
300 | self, request: Request, group_id: str, role_id: str | |
313 | self, request: SynapseRequest, group_id: str, role_id: str | |
301 | 314 | ) -> Tuple[int, JsonDict]: |
302 | 315 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
303 | 316 | requester_user_id = requester.user.to_string() |
310 | 323 | |
311 | 324 | @_validate_group_id |
312 | 325 | async def on_PUT( |
313 | self, request: Request, group_id: str, role_id: str | |
326 | self, request: SynapseRequest, group_id: str, role_id: str | |
314 | 327 | ) -> Tuple[int, JsonDict]: |
315 | 328 | requester = await self.auth.get_user_by_req(request) |
316 | 329 | requester_user_id = requester.user.to_string() |
338 | 351 | |
339 | 352 | @_validate_group_id |
340 | 353 | async def on_DELETE( |
341 | self, request: Request, group_id: str, role_id: str | |
354 | self, request: SynapseRequest, group_id: str, role_id: str | |
342 | 355 | ) -> Tuple[int, JsonDict]: |
343 | 356 | requester = await self.auth.get_user_by_req(request) |
344 | 357 | requester_user_id = requester.user.to_string() |
365 | 378 | self.groups_handler = hs.get_groups_local_handler() |
366 | 379 | |
367 | 380 | @_validate_group_id |
368 | async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
381 | async def on_GET( | |
382 | self, request: SynapseRequest, group_id: str | |
383 | ) -> Tuple[int, JsonDict]: | |
369 | 384 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
370 | 385 | requester_user_id = requester.user.to_string() |
371 | 386 | |
398 | 413 | |
399 | 414 | @_validate_group_id |
400 | 415 | async def on_PUT( |
401 | self, request: Request, group_id: str, role_id: Optional[str], user_id: str | |
416 | self, | |
417 | request: SynapseRequest, | |
418 | group_id: str, | |
419 | role_id: Optional[str], | |
420 | user_id: str, | |
402 | 421 | ) -> Tuple[int, JsonDict]: |
403 | 422 | requester = await self.auth.get_user_by_req(request) |
404 | 423 | requester_user_id = requester.user.to_string() |
430 | 449 | |
431 | 450 | @_validate_group_id |
432 | 451 | async def on_DELETE( |
433 | self, request: Request, group_id: str, role_id: str, user_id: str | |
452 | self, request: SynapseRequest, group_id: str, role_id: str, user_id: str | |
434 | 453 | ): |
435 | 454 | requester = await self.auth.get_user_by_req(request) |
436 | 455 | requester_user_id = requester.user.to_string() |
457 | 476 | self.groups_handler = hs.get_groups_local_handler() |
458 | 477 | |
459 | 478 | @_validate_group_id |
460 | async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
479 | async def on_GET( | |
480 | self, request: SynapseRequest, group_id: str | |
481 | ) -> Tuple[int, JsonDict]: | |
461 | 482 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
462 | 483 | requester_user_id = requester.user.to_string() |
463 | 484 | |
480 | 501 | self.groups_handler = hs.get_groups_local_handler() |
481 | 502 | |
482 | 503 | @_validate_group_id |
483 | async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
504 | async def on_GET( | |
505 | self, request: SynapseRequest, group_id: str | |
506 | ) -> Tuple[int, JsonDict]: | |
484 | 507 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
485 | 508 | requester_user_id = requester.user.to_string() |
486 | 509 | |
503 | 526 | self.groups_handler = hs.get_groups_local_handler() |
504 | 527 | |
505 | 528 | @_validate_group_id |
506 | async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
529 | async def on_GET( | |
530 | self, request: SynapseRequest, group_id: str | |
531 | ) -> Tuple[int, JsonDict]: | |
507 | 532 | requester = await self.auth.get_user_by_req(request) |
508 | 533 | requester_user_id = requester.user.to_string() |
509 | 534 | |
525 | 550 | self.groups_handler = hs.get_groups_local_handler() |
526 | 551 | |
527 | 552 | @_validate_group_id |
528 | async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
553 | async def on_PUT( | |
554 | self, request: SynapseRequest, group_id: str | |
555 | ) -> Tuple[int, JsonDict]: | |
529 | 556 | requester = await self.auth.get_user_by_req(request) |
530 | 557 | requester_user_id = requester.user.to_string() |
531 | 558 | |
553 | 580 | self.groups_handler = hs.get_groups_local_handler() |
554 | 581 | self.server_name = hs.hostname |
555 | 582 | |
556 | async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: | |
583 | async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |
557 | 584 | requester = await self.auth.get_user_by_req(request) |
558 | 585 | requester_user_id = requester.user.to_string() |
559 | 586 | |
597 | 624 | |
598 | 625 | @_validate_group_id |
599 | 626 | async def on_PUT( |
600 | self, request: Request, group_id: str, room_id: str | |
627 | self, request: SynapseRequest, group_id: str, room_id: str | |
601 | 628 | ) -> Tuple[int, JsonDict]: |
602 | 629 | requester = await self.auth.get_user_by_req(request) |
603 | 630 | requester_user_id = requester.user.to_string() |
614 | 641 | |
615 | 642 | @_validate_group_id |
616 | 643 | async def on_DELETE( |
617 | self, request: Request, group_id: str, room_id: str | |
644 | self, request: SynapseRequest, group_id: str, room_id: str | |
618 | 645 | ) -> Tuple[int, JsonDict]: |
619 | 646 | requester = await self.auth.get_user_by_req(request) |
620 | 647 | requester_user_id = requester.user.to_string() |
645 | 672 | |
646 | 673 | @_validate_group_id |
647 | 674 | async def on_PUT( |
648 | self, request: Request, group_id: str, room_id: str, config_key: str | |
675 | self, request: SynapseRequest, group_id: str, room_id: str, config_key: str | |
649 | 676 | ): |
650 | 677 | requester = await self.auth.get_user_by_req(request) |
651 | 678 | requester_user_id = requester.user.to_string() |
677 | 704 | self.is_mine_id = hs.is_mine_id |
678 | 705 | |
679 | 706 | @_validate_group_id |
680 | async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: | |
707 | async def on_PUT( | |
708 | self, request: SynapseRequest, group_id, user_id | |
709 | ) -> Tuple[int, JsonDict]: | |
681 | 710 | requester = await self.auth.get_user_by_req(request) |
682 | 711 | requester_user_id = requester.user.to_string() |
683 | 712 | |
707 | 736 | self.groups_handler = hs.get_groups_local_handler() |
708 | 737 | |
709 | 738 | @_validate_group_id |
710 | async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: | |
739 | async def on_PUT( | |
740 | self, request: SynapseRequest, group_id, user_id | |
741 | ) -> Tuple[int, JsonDict]: | |
711 | 742 | requester = await self.auth.get_user_by_req(request) |
712 | 743 | requester_user_id = requester.user.to_string() |
713 | 744 | |
734 | 765 | self.groups_handler = hs.get_groups_local_handler() |
735 | 766 | |
736 | 767 | @_validate_group_id |
737 | async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
768 | async def on_PUT( | |
769 | self, request: SynapseRequest, group_id: str | |
770 | ) -> Tuple[int, JsonDict]: | |
738 | 771 | requester = await self.auth.get_user_by_req(request) |
739 | 772 | requester_user_id = requester.user.to_string() |
740 | 773 | |
761 | 794 | self.groups_handler = hs.get_groups_local_handler() |
762 | 795 | |
763 | 796 | @_validate_group_id |
764 | async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
797 | async def on_PUT( | |
798 | self, request: SynapseRequest, group_id: str | |
799 | ) -> Tuple[int, JsonDict]: | |
765 | 800 | requester = await self.auth.get_user_by_req(request) |
766 | 801 | requester_user_id = requester.user.to_string() |
767 | 802 | |
788 | 823 | self.groups_handler = hs.get_groups_local_handler() |
789 | 824 | |
790 | 825 | @_validate_group_id |
791 | async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
826 | async def on_PUT( | |
827 | self, request: SynapseRequest, group_id: str | |
828 | ) -> Tuple[int, JsonDict]: | |
792 | 829 | requester = await self.auth.get_user_by_req(request) |
793 | 830 | requester_user_id = requester.user.to_string() |
794 | 831 | |
815 | 852 | self.store = hs.get_datastore() |
816 | 853 | |
817 | 854 | @_validate_group_id |
818 | async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: | |
855 | async def on_PUT( | |
856 | self, request: SynapseRequest, group_id: str | |
857 | ) -> Tuple[int, JsonDict]: | |
819 | 858 | requester = await self.auth.get_user_by_req(request) |
820 | 859 | requester_user_id = requester.user.to_string() |
821 | 860 | |
838 | 877 | self.store = hs.get_datastore() |
839 | 878 | self.groups_handler = hs.get_groups_local_handler() |
840 | 879 | |
841 | async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: | |
880 | async def on_GET( | |
881 | self, request: SynapseRequest, user_id: str | |
882 | ) -> Tuple[int, JsonDict]: | |
842 | 883 | await self.auth.get_user_by_req(request, allow_guest=True) |
843 | 884 | |
844 | 885 | result = await self.groups_handler.get_publicised_groups_for_user(user_id) |
858 | 899 | self.store = hs.get_datastore() |
859 | 900 | self.groups_handler = hs.get_groups_local_handler() |
860 | 901 | |
861 | async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: | |
902 | async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |
862 | 903 | await self.auth.get_user_by_req(request, allow_guest=True) |
863 | 904 | |
864 | 905 | content = parse_json_object_from_request(request) |
880 | 921 | self.clock = hs.get_clock() |
881 | 922 | self.groups_handler = hs.get_groups_local_handler() |
882 | 923 | |
883 | async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: | |
924 | async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | |
884 | 925 | requester = await self.auth.get_user_by_req(request, allow_guest=True) |
885 | 926 | requester_user_id = requester.user.to_string() |
886 | 927 |
19 | 19 | from twisted.web.server import Request |
20 | 20 | |
21 | 21 | from synapse.http.server import DirectServeJsonResource, respond_with_json |
22 | from synapse.http.site import SynapseRequest | |
22 | 23 | |
23 | 24 | if TYPE_CHECKING: |
24 | 25 | from synapse.app.homeserver import HomeServer |
34 | 35 | self.auth = hs.get_auth() |
35 | 36 | self.limits_dict = {"m.upload.size": config.max_upload_size} |
36 | 37 | |
37 | async def _async_render_GET(self, request: Request) -> None: | |
38 | async def _async_render_GET(self, request: SynapseRequest) -> None: | |
38 | 39 | await self.auth.get_user_by_req(request) |
39 | 40 | respond_with_json(request, 200, self.limits_dict, send_cors=True) |
40 | 41 |
34 | 34 | from synapse.config._base import ConfigError |
35 | 35 | from synapse.logging.context import defer_to_thread |
36 | 36 | from synapse.metrics.background_process_metrics import run_as_background_process |
37 | from synapse.types import UserID | |
37 | 38 | from synapse.util.async_helpers import Linearizer |
38 | 39 | from synapse.util.retryutils import NotRetryingDestination |
39 | 40 | from synapse.util.stringutils import random_string |
144 | 145 | upload_name: Optional[str], |
145 | 146 | content: IO, |
146 | 147 | content_length: int, |
147 | auth_user: str, | |
148 | auth_user: UserID, | |
148 | 149 | ) -> str: |
149 | 150 | """Store uploaded content for a local user and return the mxc URL |
150 | 151 |
38 | 38 | respond_with_json_bytes, |
39 | 39 | ) |
40 | 40 | from synapse.http.servlet import parse_integer, parse_string |
41 | from synapse.http.site import SynapseRequest | |
41 | 42 | from synapse.logging.context import make_deferred_yieldable, run_in_background |
42 | 43 | from synapse.metrics.background_process_metrics import run_as_background_process |
43 | 44 | from synapse.rest.media.v1._base import get_filename_from_headers |
184 | 185 | request.setHeader(b"Allow", b"OPTIONS, GET") |
185 | 186 | respond_with_json(request, 200, {}, send_cors=True) |
186 | 187 | |
187 | async def _async_render_GET(self, request: Request) -> None: | |
188 | async def _async_render_GET(self, request: SynapseRequest) -> None: | |
188 | 189 | |
189 | 190 | # XXX: if get_user_by_req fails, what should we do in an async render? |
190 | 191 | requester = await self.auth.get_user_by_req(request) |
95 | 95 | def _resize(self, width: int, height: int) -> Image: |
96 | 96 | # 1-bit or 8-bit color palette images need converting to RGB |
97 | 97 | # otherwise they will be scaled using nearest neighbour which |
98 | # looks awful | |
99 | if self.image.mode in ["1", "P"]: | |
100 | self.image = self.image.convert("RGB") | |
98 | # looks awful. | |
99 | # | |
100 | # If the image has transparency, use RGBA instead. | |
101 | if self.image.mode in ["1", "L", "P"]: | |
102 | mode = "RGB" | |
103 | if self.image.info.get("transparency", None) is not None: | |
104 | mode = "RGBA" | |
105 | self.image = self.image.convert(mode) | |
101 | 106 | return self.image.resize((width, height), Image.ANTIALIAS) |
102 | 107 | |
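A standalone sketch of the conversion rule added above, assuming Pillow is available (the input file name is made up):

    from PIL import Image

    img = Image.open("palette_image.png")  # hypothetical input
    if img.mode in ("1", "L", "P"):
        # pick RGBA when the image carries transparency info, so the alpha
        # channel survives the conversion; plain RGB otherwise
        mode = "RGBA" if img.info.get("transparency") is not None else "RGB"
        img = img.convert(mode)
    thumbnail = img.resize((64, 64), Image.ANTIALIAS)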
103 | 108 | def scale(self, width: int, height: int, output_type: str) -> BytesIO: |
21 | 21 | from synapse.api.errors import Codes, SynapseError |
22 | 22 | from synapse.http.server import DirectServeJsonResource, respond_with_json |
23 | 23 | from synapse.http.servlet import parse_string |
24 | from synapse.http.site import SynapseRequest | |
24 | 25 | from synapse.rest.media.v1.media_storage import SpamMediaException |
25 | 26 | |
26 | 27 | if TYPE_CHECKING: |
48 | 49 | async def _async_render_OPTIONS(self, request: Request) -> None: |
49 | 50 | respond_with_json(request, 200, {}, send_cors=True) |
50 | 51 | |
51 | async def _async_render_POST(self, request: Request) -> None: | |
52 | async def _async_render_POST(self, request: SynapseRequest) -> None: | |
52 | 53 | requester = await self.auth.get_user_by_req(request) |
53 | 54 | # TODO: The checks here are a bit late. The content will have |
54 | 55 | # already been uploaded to a tmp file at this point |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | |
16 | from typing import TYPE_CHECKING | |
17 | ||
16 | 18 | from synapse.http.server import DirectServeHtmlResource |
19 | ||
20 | if TYPE_CHECKING: | |
21 | from synapse.server import HomeServer | |
17 | 22 | |
18 | 23 | |
19 | 24 | class SAML2ResponseResource(DirectServeHtmlResource): |
21 | 26 | |
22 | 27 | isLeaf = 1 |
23 | 28 | |
24 | def __init__(self, hs): | |
29 | def __init__(self, hs: "HomeServer"): | |
25 | 30 | super().__init__() |
26 | 31 | self._saml_handler = hs.get_saml_handler() |
32 | self._sso_handler = hs.get_sso_handler() | |
27 | 33 | |
28 | 34 | async def _async_render_GET(self, request): |
29 | 35 | # We're not expecting any GET request on that resource if everything goes right, |
30 | 36 | # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. |
31 | 37 | # In this case, just tell the user that something went wrong and they should |
32 | 38 | # try to authenticate again. |
33 | self._saml_handler._render_error( | |
39 | self._sso_handler.render_error( | |
34 | 40 | request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" |
35 | 41 | ) |
36 | 42 |
35 | 35 | cast, |
36 | 36 | ) |
37 | 37 | |
38 | import twisted.internet.base | |
39 | 38 | import twisted.internet.tcp |
40 | 39 | from twisted.internet import defer |
41 | 40 | from twisted.mail.smtp import sendmail |
129 | 128 | from synapse.state import StateHandler, StateResolutionHandler |
130 | 129 | from synapse.storage import Databases, DataStore, Storage |
131 | 130 | from synapse.streams.events import EventSources |
132 | from synapse.types import DomainSpecificString | |
131 | from synapse.types import DomainSpecificString, ISynapseReactor | |
133 | 132 | from synapse.util import Clock |
134 | 133 | from synapse.util.distributor import Distributor |
135 | 134 | from synapse.util.ratelimitutils import FederationRateLimiter |
290 | 289 | for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: |
291 | 290 | getattr(self, "get_" + i + "_handler")() |
292 | 291 | |
293 | def get_reactor(self) -> twisted.internet.base.ReactorBase: | |
292 | def get_reactor(self) -> ISynapseReactor: | |
294 | 293 | """ |
295 | 294 | Fetch the Twisted reactor in use by this HomeServer. |
296 | 295 | """ |
351 | 350 | |
352 | 351 | @cache_in_self |
353 | 352 | def get_http_client_context_factory(self) -> IPolicyForHTTPS: |
354 | return ( | |
355 | InsecureInterceptableContextFactory() | |
356 | if self.config.use_insecure_ssl_client_just_for_testing_do_not_use | |
357 | else RegularPolicyForHTTPS() | |
358 | ) | |
353 | if self.config.use_insecure_ssl_client_just_for_testing_do_not_use: | |
354 | return InsecureInterceptableContextFactory() | |
355 | return RegularPolicyForHTTPS() | |
359 | 356 | |
360 | 357 | @cache_in_self |
361 | 358 | def get_simple_http_client(self) -> SimpleHttpClient: |
53 | 53 | ) # type: LruCache[str, List[Tuple[str, int]]] |
54 | 54 | |
55 | 55 | async def get_auth_chain( |
56 | self, event_ids: Collection[str], include_given: bool = False | |
56 | self, room_id: str, event_ids: Collection[str], include_given: bool = False | |
57 | 57 | ) -> List[EventBase]: |
58 | 58 | """Get auth events for given event_ids. The events *must* be state events. |
59 | 59 | |
60 | 60 | Args: |
61 | room_id: The room the event is in. | |
61 | 62 | event_ids: state events |
62 | 63 | include_given: include the given events in result |
63 | 64 | |
65 | 66 | list of events |
66 | 67 | """ |
67 | 68 | event_ids = await self.get_auth_chain_ids( |
68 | event_ids, include_given=include_given | |
69 | room_id, event_ids, include_given=include_given | |
69 | 70 | ) |
70 | 71 | return await self.get_events_as_list(event_ids) |
71 | 72 | |
72 | 73 | async def get_auth_chain_ids( |
73 | 74 | self, |
75 | room_id: str, | |
74 | 76 | event_ids: Collection[str], |
75 | 77 | include_given: bool = False, |
76 | 78 | ) -> List[str]: |
77 | 79 | """Get auth events for given event_ids. The events *must* be state events. |
78 | 80 | |
79 | 81 | Args: |
82 | room_id: The room the event is in. | |
80 | 83 | event_ids: state events |
81 | 84 | include_given: include the given events in result |
82 | 85 | |
83 | 86 | Returns: |
84 | An awaitable which resolve to a list of event_ids | |
85 | """ | |
87 | list of event_ids | |
88 | """ | |
89 | ||
90 | # Check if we have indexed the room so we can use the chain cover | |
91 | # algorithm. | |
92 | room = await self.get_room(room_id) | |
93 | if room["has_auth_chain_index"]: | |
94 | try: | |
95 | return await self.db_pool.runInteraction( | |
96 | "get_auth_chain_ids_chains", | |
97 | self._get_auth_chain_ids_using_cover_index_txn, | |
98 | room_id, | |
99 | event_ids, | |
100 | include_given, | |
101 | ) | |
102 | except _NoChainCoverIndex: | |
103 | # For whatever reason we don't actually have a chain cover index | |
104 | # for the events in question, so we fall back to the old method. | |
105 | pass | |
106 | ||
86 | 107 | return await self.db_pool.runInteraction( |
87 | 108 | "get_auth_chain_ids", |
88 | 109 | self._get_auth_chain_ids_txn, |
90 | 111 | include_given, |
91 | 112 | ) |
92 | 113 | |
114 | def _get_auth_chain_ids_using_cover_index_txn( | |
115 | self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool | |
116 | ) -> List[str]: | |
117 | """Calculates the auth chain IDs using the chain index.""" | |
118 | ||
119 | # First we look up the chain ID/sequence numbers for the given events. | |
120 | ||
121 | initial_events = set(event_ids) | |
122 | ||
123 | # All the events that we've found that are reachable from the events. | |
124 | seen_events = set() # type: Set[str] | |
125 | ||
126 | # A map from chain ID to max sequence number of the given events. | |
127 | event_chains = {} # type: Dict[int, int] | |
128 | ||
129 | sql = """ | |
130 | SELECT event_id, chain_id, sequence_number | |
131 | FROM event_auth_chains | |
132 | WHERE %s | |
133 | """ | |
134 | for batch in batch_iter(initial_events, 1000): | |
135 | clause, args = make_in_list_sql_clause( | |
136 | txn.database_engine, "event_id", batch | |
137 | ) | |
138 | txn.execute(sql % (clause,), args) | |
139 | ||
140 | for event_id, chain_id, sequence_number in txn: | |
141 | seen_events.add(event_id) | |
142 | event_chains[chain_id] = max( | |
143 | sequence_number, event_chains.get(chain_id, 0) | |
144 | ) | |
145 | ||
146 | # Check that we actually have a chain ID for all the events. | |
147 | events_missing_chain_info = initial_events.difference(seen_events) | |
148 | if events_missing_chain_info: | |
149 | # This can happen due to e.g. downgrade/upgrade of the server. We | |
150 | # raise an exception and fall back to the previous algorithm. | |
151 | logger.info( | |
152 | "Unexpectedly found that events don't have chain IDs in room %s: %s", | |
153 | room_id, | |
154 | events_missing_chain_info, | |
155 | ) | |
156 | raise _NoChainCoverIndex(room_id) | |
157 | ||
158 | # Now we look up all links for the chains we have, adding chains that | |
159 | # are reachable from any event. | |
160 | sql = """ | |
161 | SELECT | |
162 | origin_chain_id, origin_sequence_number, | |
163 | target_chain_id, target_sequence_number | |
164 | FROM event_auth_chain_links | |
165 | WHERE %s | |
166 | """ | |
167 | ||
168 | # A map from chain ID to max sequence number *reachable* from any event ID. | |
169 | chains = {} # type: Dict[int, int] | |
170 | ||
171 | # Add all linked chains reachable from initial set of chains. | |
172 | for batch in batch_iter(event_chains, 1000): | |
173 | clause, args = make_in_list_sql_clause( | |
174 | txn.database_engine, "origin_chain_id", batch | |
175 | ) | |
176 | txn.execute(sql % (clause,), args) | |
177 | ||
178 | for ( | |
179 | origin_chain_id, | |
180 | origin_sequence_number, | |
181 | target_chain_id, | |
182 | target_sequence_number, | |
183 | ) in txn: | |
184 | # chains are only reachable if the origin sequence number of |
185 | # the link is less than or equal to the max sequence number |
186 | # in the origin chain. |
187 | if origin_sequence_number <= event_chains.get(origin_chain_id, 0): | |
188 | chains[target_chain_id] = max( | |
189 | target_sequence_number, | |
190 | chains.get(target_chain_id, 0), | |
191 | ) | |
192 | ||
193 | # Add the initial set of chains, excluding the sequence numbers of the |
194 | # given events themselves. |
195 | for chain_id, seq_no in event_chains.items(): | |
196 | chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0)) | |
197 | ||
198 | # Now for each chain we figure out the maximum sequence number reachable |
199 | # from *any* event ID. Events with a sequence number less than or equal |
200 | # to that are in the auth chain. |
201 | if include_given: | |
202 | results = initial_events | |
203 | else: | |
204 | results = set() | |
205 | ||
206 | if isinstance(self.database_engine, PostgresEngine): | |
207 | # We can use `execute_values` to efficiently fetch the gaps when | |
208 | # using postgres. | |
209 | sql = """ | |
210 | SELECT event_id | |
211 | FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) | |
212 | WHERE | |
213 | c.chain_id = l.chain_id | |
214 | AND sequence_number <= max_seq | |
215 | """ | |
216 | ||
217 | rows = txn.execute_values(sql, chains.items()) | |
218 | results.update(r for r, in rows) | |
219 | else: | |
220 | # For SQLite we just fall back to doing a noddy for loop. | |
221 | sql = """ | |
222 | SELECT event_id FROM event_auth_chains | |
223 | WHERE chain_id = ? AND sequence_number <= ? | |
224 | """ | |
225 | for chain_id, max_no in chains.items(): | |
226 | txn.execute(sql, (chain_id, max_no)) | |
227 | results.update(r for r, in txn) | |
228 | ||
229 | return list(results) | |
230 | ||
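A toy, in-memory rerun of the reachability step above, with made-up chain IDs and links, showing which links count and the resulting per-chain maxima:

    # event_chains: max sequence number of the *given* events, per chain
    event_chains = {1: 3, 2: 5}
    # links: (origin_chain, origin_seq, target_chain, target_seq)
    links = [(1, 2, 3, 7), (1, 4, 4, 1), (2, 5, 3, 2)]

    chains = {}
    for o_chain, o_seq, t_chain, t_seq in links:
        # a link only counts if its origin point is at or below the max
        # sequence number we hold in the origin chain
        if o_seq <= event_chains.get(o_chain, 0):
            chains[t_chain] = max(t_seq, chains.get(t_chain, 0))
    # (1, 4, 4, 1) was skipped: origin seq 4 > max seq 3 held in chain 1

    # add the given chains themselves, minus the given events
    for chain_id, seq_no in event_chains.items():
        chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))

    assert chains == {3: 7, 1: 2, 2: 4}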
93 | 231 | def _get_auth_chain_ids_txn( |
94 | 232 | self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool |
95 | 233 | ) -> List[str]: |
234 | """Calculates the auth chain IDs. | |
235 | ||
236 | This is used when we don't have a cover index for the room. | |
237 | """ | |
96 | 238 | if include_given: |
97 | 239 | results = set(event_ids) |
98 | 240 | else: |
132 | 132 | self.db_pool.updates.register_background_update_handler( |
133 | 133 | "chain_cover", |
134 | 134 | self._chain_cover_index, |
135 | ) | |
136 | ||
137 | self.db_pool.updates.register_background_update_handler( | |
138 | "purged_chain_cover", | |
139 | self._purged_chain_cover_index, | |
135 | 140 | ) |
136 | 141 | |
137 | 142 | async def _background_reindex_fields_sender(self, progress, batch_size): |
931 | 936 | processed_count=count, |
932 | 937 | finished_room_map=finished_rooms, |
933 | 938 | ) |
939 | ||
940 | async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int: | |
941 | """ | |
942 | A background update that iterates over the chain cover and deletes the |
943 | chain cover for events that have been purged. |
944 | |
945 | This may be due to fully purging a room or to a retention policy. |
946 | """ | |
947 | current_event_id = progress.get("current_event_id", "") | |
948 | ||
949 | def purged_chain_cover_txn(txn) -> int: | |
950 | # The event ID from the events table will be NULL if the chain ID / |
951 | # sequence number points to a purged event. |
952 | sql = """ | |
953 | SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL | |
954 | FROM event_auth_chains | |
955 | LEFT JOIN events AS e USING (event_id) | |
956 | WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ? | |
957 | """ | |
958 | txn.execute(sql, (current_event_id, batch_size)) | |
959 | ||
960 | rows = txn.fetchall() | |
961 | if not rows: | |
962 | return 0 | |
963 | ||
964 | # The event IDs and chain IDs / sequence numbers where the event has | |
965 | # been purged. | |
966 | unreferenced_event_ids = [] | |
967 | unreferenced_chain_id_tuples = [] | |
968 | event_id = "" | |
969 | for event_id, chain_id, sequence_number, has_event in rows: | |
970 | if not has_event: | |
971 | unreferenced_event_ids.append((event_id,)) | |
972 | unreferenced_chain_id_tuples.append((chain_id, sequence_number)) | |
973 | ||
974 | # Delete the unreferenced rows from event_auth_chains and |
975 | # event_auth_chain_links. |
976 | txn.executemany( | |
977 | """ | |
978 | DELETE FROM event_auth_chains WHERE event_id = ? | |
979 | """, | |
980 | unreferenced_event_ids, | |
981 | ) | |
982 | # We should also delete matching target_*, but there is no index on | |
983 | # target_chain_id. Hopefully any purged events are due to a room | |
984 | # being fully purged and they will be removed from the origin_* | |
985 | # searches. | |
986 | txn.executemany( | |
987 | """ | |
988 | DELETE FROM event_auth_chain_links WHERE | |
989 | origin_chain_id = ? AND origin_sequence_number = ? | |
990 | """, | |
991 | unreferenced_chain_id_tuples, | |
992 | ) | |
993 | ||
994 | progress = { | |
995 | "current_event_id": event_id, | |
996 | } | |
997 | ||
998 | self.db_pool.updates._background_update_progress_txn( | |
999 | txn, "purged_chain_cover", progress | |
1000 | ) | |
1001 | ||
1002 | return len(rows) | |
1003 | ||
1004 | result = await self.db_pool.runInteraction( | |
1005 | "_purged_chain_cover_index", | |
1006 | purged_chain_cover_txn, | |
1007 | ) | |
1008 | ||
1009 | if not result: | |
1010 | await self.db_pool.updates._end_background_update("purged_chain_cover") | |
1011 | ||
1012 | return result |
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | import itertools | |
14 | ||
15 | 15 | import logging |
16 | 16 | import threading |
17 | 17 | from collections import namedtuple |
1043 | 1043 | Returns: |
1044 | 1044 | set[str]: The events we have already seen. |
1045 | 1045 | """ |
1046 | results = set() | |
1046 | # if the event cache contains the event, obviously we've seen it. | |
1047 | results = {x for x in event_ids if self._get_event_cache.contains(x)} | |
1047 | 1048 | |
1048 | 1049 | def have_seen_events_txn(txn, chunk): |
1049 | 1050 | sql = "SELECT event_id FROM events as e WHERE " |
1051 | 1052 | txn.database_engine, "e.event_id", chunk |
1052 | 1053 | ) |
1053 | 1054 | txn.execute(sql + clause, args) |
1054 | for (event_id,) in txn: | |
1055 | results.add(event_id) | |
1056 | ||
1057 | # break the input up into chunks of 100 | |
1058 | input_iterator = iter(event_ids) | |
1059 | for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): | |
1055 | results.update(row[0] for row in txn) | |
1056 | ||
1057 | for chunk in batch_iter((x for x in event_ids if x not in results), 100): | |
1060 | 1058 | await self.db_pool.runInteraction( |
1061 | 1059 | "have_seen_events", have_seen_events_txn, chunk |
1062 | 1060 | ) |
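batch_iter here comes from synapse.util.iterutils; for reference, a minimal equivalent is:

    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        """Yield tuples of up to `size` items from any iterable."""
        sourceiter = iter(iterable)
        # iter(callable, sentinel): keep slicing until an empty tuple appears
        return iter(lambda: tuple(islice(sourceiter, size)), ())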
330 | 330 | txn.executemany( |
331 | 331 | """ |
332 | 332 | DELETE FROM event_auth_chain_links WHERE |
333 | (origin_chain_id = ? AND origin_sequence_number = ?) OR | |
334 | (target_chain_id = ? AND target_sequence_number = ?) | |
333 | origin_chain_id = ? AND origin_sequence_number = ? | |
335 | 334 | """, |
336 | ( | |
337 | (chain_id, seq_num, chain_id, seq_num) | |
338 | for (chain_id, seq_num) in referenced_chain_id_tuples | |
339 | ), | |
335 | referenced_chain_id_tuples, | |
340 | 336 | ) |
341 | 337 | |
342 | 338 | # Now we delete tables which lack an index on room_id but have one on event_id |
15 | 15 | # limitations under the License. |
16 | 16 | import logging |
17 | 17 | import re |
18 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple | |
18 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | |
19 | 19 | |
20 | 20 | import attr |
21 | 21 | |
1509 | 1509 | async def user_delete_access_tokens( |
1510 | 1510 | self, |
1511 | 1511 | user_id: str, |
1512 | except_token_id: Optional[str] = None, | |
1512 | except_token_id: Optional[int] = None, | |
1513 | 1513 | device_id: Optional[str] = None, |
1514 | 1514 | ) -> List[Tuple[str, int, Optional[str]]]: |
1515 | 1515 | """ |
1532 | 1532 | |
1533 | 1533 | items = keyvalues.items() |
1534 | 1534 | where_clause = " AND ".join(k + " = ?" for k, _ in items) |
1535 | values = [v for _, v in items] | |
1535 | values = [v for _, v in items] # type: List[Union[str, int]] | |
1536 | 1536 | if except_token_id: |
1537 | 1537 | where_clause += " AND id != ?" |
1538 | 1538 | values.append(except_token_id) |
0 | /* Copyright 2021 The Matrix.org Foundation C.I.C | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | INSERT INTO background_updates (ordering, update_name, progress_json) VALUES | |
16 | (5910, 'purged_chain_cover', '{}'); |
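The update_name inserted here must match the handler registered earlier in this diff; Synapse's background-update runner reads rows out of background_updates and keeps calling the matching handler until it reports completion:

    # the registration shown earlier in this diff:
    self.db_pool.updates.register_background_update_handler(
        "purged_chain_cover",
        self._purged_chain_cover_index,
    )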
349 | 349 | |
350 | 350 | self.db_pool.simple_upsert_many_txn( |
351 | 351 | txn, |
352 | "destination_rooms", | |
353 | ["destination", "room_id"], | |
354 | rows, | |
355 | ["stream_ordering"], | |
356 | [(stream_ordering,)] * len(rows), | |
352 | table="destination_rooms", | |
353 | key_names=("destination", "room_id"), | |
354 | key_values=rows, | |
355 | value_names=["stream_ordering"], | |
356 | value_values=[(stream_ordering,)] * len(rows), | |
357 | 357 | ) |
358 | 358 | |
359 | 359 | async def get_destination_last_successful_stream_ordering( |
34 | 34 | import attr |
35 | 35 | from signedjson.key import decode_verify_key_bytes |
36 | 36 | from unpaddedbase64 import decode_base64 |
37 | from zope.interface import Interface | |
38 | ||
39 | from twisted.internet.interfaces import ( | |
40 | IReactorCore, | |
41 | IReactorPluggableNameResolver, | |
42 | IReactorTCP, | |
43 | IReactorTime, | |
44 | ) | |
37 | 45 | |
38 | 46 | from synapse.api.errors import Codes, SynapseError |
39 | 47 | from synapse.util.stringutils import parse_and_validate_server_name |
66 | 74 | JsonDict = Dict[str, Any] |
67 | 75 | |
68 | 76 | |
69 | class Requester( | |
70 | namedtuple( | |
71 | "Requester", | |
72 | [ | |
73 | "user", | |
74 | "access_token_id", | |
75 | "is_guest", | |
76 | "shadow_banned", | |
77 | "device_id", | |
78 | "app_service", | |
79 | "authenticated_entity", | |
80 | ], | |
81 | ) | |
77 | # Note that this seems to require inheriting *directly* from Interface in order | |
78 | # for mypy-zope to realize it is an interface. | |
79 | class ISynapseReactor( | |
80 | IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface | |
82 | 81 | ): |
82 | """The interfaces necessary for Synapse to function.""" | |
83 | ||
84 | ||
85 | @attr.s(frozen=True, slots=True) | |
86 | class Requester: | |
83 | 87 | """ |
84 | 88 | Represents the user making a request |
85 | 89 | |
86 | 90 | Attributes: |
87 | user (UserID): id of the user making the request | |
88 | access_token_id (int|None): *ID* of the access token used for this | |
91 | user: id of the user making the request | |
92 | access_token_id: *ID* of the access token used for this | |
89 | 93 | request, or None if it came via the appservice API or similar |
90 | is_guest (bool): True if the user making this request is a guest user | |
91 | shadow_banned (bool): True if the user making this request has been shadow-banned. | |
92 | device_id (str|None): device_id which was set at authentication time | |
93 | app_service (ApplicationService|None): the AS requesting on behalf of the user | |
94 | """ | |
94 | is_guest: True if the user making this request is a guest user | |
95 | shadow_banned: True if the user making this request has been shadow-banned. | |
96 | device_id: device_id which was set at authentication time | |
97 | app_service: the AS requesting on behalf of the user | |
98 | authenticated_entity: The entity that authenticated when making the request. | |
99 | This is different to the user_id when an admin user or the server is | |
100 | "puppeting" the user. | |
101 | """ | |
102 | ||
103 | user = attr.ib(type="UserID") | |
104 | access_token_id = attr.ib(type=Optional[int]) | |
105 | is_guest = attr.ib(type=bool) | |
106 | shadow_banned = attr.ib(type=bool) | |
107 | device_id = attr.ib(type=Optional[str]) | |
108 | app_service = attr.ib(type=Optional["ApplicationService"]) | |
109 | authenticated_entity = attr.ib(type=str) | |
95 | 110 | |
96 | 111 | def serialize(self): |
97 | 112 | """Converts self to a type that can be serialized as JSON, and then |
140 | 155 | def create_requester( |
141 | 156 | user_id: Union[str, "UserID"], |
142 | 157 | access_token_id: Optional[int] = None, |
143 | is_guest: Optional[bool] = False, | |
144 | shadow_banned: Optional[bool] = False, | |
158 | is_guest: bool = False, | |
159 | shadow_banned: bool = False, | |
145 | 160 | device_id: Optional[str] = None, |
146 | 161 | app_service: Optional["ApplicationService"] = None, |
147 | 162 | authenticated_entity: Optional[str] = None, |
148 | ): | |
163 | ) -> Requester: | |
149 | 164 | """ |
150 | 165 | Create a new ``Requester`` object |
151 | 166 | |
152 | 167 | Args: |
153 | user_id (str|UserID): id of the user making the request | |
154 | access_token_id (int|None): *ID* of the access token used for this | |
168 | user_id: id of the user making the request | |
169 | access_token_id: *ID* of the access token used for this | |
155 | 170 | request, or None if it came via the appservice API or similar |
156 | is_guest (bool): True if the user making this request is a guest user | |
157 | shadow_banned (bool): True if the user making this request is shadow-banned. | |
158 | device_id (str|None): device_id which was set at authentication time | |
159 | app_service (ApplicationService|None): the AS requesting on behalf of the user | |
171 | is_guest: True if the user making this request is a guest user | |
172 | shadow_banned: True if the user making this request is shadow-banned. | |
173 | device_id: device_id which was set at authentication time | |
174 | app_service: the AS requesting on behalf of the user | |
160 | 175 | authenticated_entity: The entity that authenticated when making the request. |
161 | 176 | This is different to the user_id when an admin user or the server is |
162 | 177 | "puppeting" the user. |
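A short usage sketch of the converted class via create_requester (identifiers made up); since the class is now frozen, accidental mutation raises instead of silently rebinding an attribute:

    requester = create_requester(
        "@alice:example.com",
        access_token_id=42,
        authenticated_entity="@admin:example.com",  # an admin puppeting alice
    )
    # `requester.is_guest = True` would now raise
    # attr.exceptions.FrozenInstanceError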
75 | 75 | def callback(r): |
76 | 76 | object.__setattr__(self, "_result", (True, r)) |
77 | 77 | while self._observers: |
78 | observer = self._observers.pop() | |
78 | 79 | try: |
79 | # TODO: Handle errors here. | |
80 | self._observers.pop().callback(r) | |
81 | except Exception: | |
82 | pass | |
80 | observer.callback(r) | |
81 | except Exception as e: | |
82 | logger.exception( | |
83 | "%r threw an exception on .callback(%r), ignoring...", | |
84 | observer, | |
85 | r, | |
86 | exc_info=e, | |
87 | ) | |
83 | 88 | return r |
84 | 89 | |
85 | 90 | def errback(f): |
89 | 94 | # traces when we `await` on one of the observer deferreds. |
90 | 95 | f.value.__failure__ = f |
91 | 96 | |
97 | observer = self._observers.pop() | |
92 | 98 | try: |
93 | # TODO: Handle errors here. | |
94 | self._observers.pop().errback(f) | |
95 | except Exception: | |
96 | pass | |
99 | observer.errback(f) | |
100 | except Exception as e: | |
101 | logger.exception( | |
102 | "%r threw an exception on .errback(%r), ignoring...", | |
103 | observer, | |
104 | f, | |
105 | exc_info=e, | |
106 | ) | |
97 | 107 | |
98 | 108 | if consumeErrors: |
99 | 109 | return None |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import logging |
15 | from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar | |
15 | from typing import Any, Callable, Dict, Generic, Optional, TypeVar | |
16 | 16 | |
17 | 17 | from twisted.internet import defer |
18 | 18 | |
19 | 19 | from synapse.logging.context import make_deferred_yieldable, run_in_background |
20 | from synapse.util import Clock | |
20 | 21 | from synapse.util.async_helpers import ObservableDeferred |
21 | 22 | from synapse.util.caches import register_cache |
22 | ||
23 | if TYPE_CHECKING: | |
24 | from synapse.app.homeserver import HomeServer | |
25 | 23 | |
26 | 24 | logger = logging.getLogger(__name__) |
27 | 25 | |
36 | 34 | used rather than trying to compute a new response. |
37 | 35 | """ |
38 | 36 | |
39 | def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): | |
37 | def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): | |
40 | 38 | # Requests that haven't finished yet. |
41 | 39 | self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] |
42 | 40 | |
43 | self.clock = hs.get_clock() | |
41 | self.clock = clock | |
44 | 42 | self.timeout_sec = timeout_ms / 1000.0 |
45 | 43 | |
46 | 44 | self._name = name |
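Call sites constructing the cache change accordingly; a hedged sketch, given a HomeServer hs (cache name and timeout made up):

    cache = ResponseCache(
        hs.get_clock(), "example_cache", timeout_ms=10 * 60 * 1000
    )  # type: ResponseCache[str]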
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2020 Quentin Gliech | |
2 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
3 | # | |
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
5 | # you may not use this file except in compliance with the License. | |
6 | # You may obtain a copy of the License at | |
7 | # | |
8 | # http://www.apache.org/licenses/LICENSE-2.0 | |
9 | # | |
10 | # Unless required by applicable law or agreed to in writing, software | |
11 | # distributed under the License is distributed on an "AS IS" BASIS, | |
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
13 | # See the License for the specific language governing permissions and | |
14 | # limitations under the License. | |
15 | ||
16 | """Utilities for manipulating macaroons""" | |
17 | ||
18 | from typing import Callable, Optional | |
19 | ||
20 | import pymacaroons | |
21 | from pymacaroons.exceptions import MacaroonVerificationFailedException | |
22 | ||
23 | ||
24 | def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: | |
25 | """Extracts a caveat value from a macaroon token. | |
26 | ||
27 | Checks that there is exactly one caveat of the form "key = <val>" in the macaroon, | |
28 | and returns the extracted value. | |
29 | ||
30 | Args: | |
31 | macaroon: the token | |
32 | key: the key of the caveat to extract | |
33 | ||
34 | Returns: | |
35 | The extracted value | |
36 | ||
37 | Raises: | |
38 | MacaroonVerificationFailedException: if there are conflicting values for the | |
39 | caveat in the macaroon, or if the caveat was not found in the macaroon. | |
40 | """ | |
41 | prefix = key + " = " | |
42 | result = None # type: Optional[str] | |
43 | for caveat in macaroon.caveats: | |
44 | if not caveat.caveat_id.startswith(prefix): | |
45 | continue | |
46 | ||
47 | val = caveat.caveat_id[len(prefix) :] | |
48 | ||
49 | if result is None: | |
50 | # first time we found this caveat: record the value | |
51 | result = val | |
52 | elif val != result: | |
53 | # on subsequent occurrences, raise if the value is different. | |
54 | raise MacaroonVerificationFailedException( | |
55 | "Conflicting values for caveat " + key | |
56 | ) | |
57 | ||
58 | if result is not None: | |
59 | return result | |
60 | ||
61 | # If the caveat is not there, we raise a MacaroonVerificationFailedException. | |
62 | # Note that it is insecure to generate a macaroon without all the caveats you | |
63 | # might need (because there is nothing stopping people from adding extra caveats), | |
64 | # so if the caveat isn't there, something odd must be going on. | |
65 | raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,)) | |
66 | ||
67 | ||
68 | def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None: | |
69 | """Make a macaroon verifier which accepts 'time' caveats | |
70 | ||
71 | Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to | |
72 | the given macaroon verifier. | |
73 | ||
74 | Args: | |
75 | v: the macaroon verifier | |
76 | get_time_ms: a callable which returns the time (in milliseconds) to compare |
77 | against the caveat's expiry. Normally the current time. |
78 | """ | |
79 | ||
80 | def verify_expiry_caveat(caveat: str): | |
81 | time_msec = get_time_ms() | |
82 | prefix = "time < " | |
83 | if not caveat.startswith(prefix): | |
84 | return False | |
85 | expiry = int(caveat[len(prefix) :]) | |
86 | return time_msec < expiry | |
87 | ||
88 | v.satisfy_general(verify_expiry_caveat) |
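A hedged usage sketch for the helpers above (the macaroon's location, identifier, key and caveat values are made up; only get_value_from_macaroon and satisfy_expiry come from this module):

    import time

    import pymacaroons

    mac = pymacaroons.Macaroon(
        location="example.com", identifier="key1", key="secret"
    )
    mac.add_first_party_caveat("user_id = @alice:example.com")
    mac.add_first_party_caveat(
        "time < %d" % (int(time.time() * 1000) + 60000,)
    )

    v = pymacaroons.Verifier()
    v.satisfy_exact("user_id = @alice:example.com")
    satisfy_expiry(v, lambda: int(time.time() * 1000))
    assert v.verify(mac, "secret")

    assert get_value_from_macaroon(mac, "user_id") == "@alice:example.com"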
6 | 6 | from synapse.federation.units import Edu |
7 | 7 | from synapse.rest import admin |
8 | 8 | from synapse.rest.client.v1 import login, room |
9 | from synapse.util.retryutils import NotRetryingDestination | |
9 | 10 | |
10 | 11 | from tests.test_utils import event_injection, make_awaitable |
11 | 12 | from tests.unittest import FederatingHomeserverTestCase, override_config |
48 | 49 | else: |
49 | 50 | data = json_cb() |
50 | 51 | self.failed_pdus.extend(data["pdus"]) |
51 | raise IOError("Failed to connect because this is a test!") | |
52 | raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination) | |
52 | 53 | |
53 | 54 | def get_destination_room(self, room: str, destination: str = "host2") -> dict: |
54 | 55 | """ |
0 | -----BEGIN PRIVATE KEY----- | |
1 | MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp | |
2 | Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA | |
3 | P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW | |
4 | -----END PRIVATE KEY----- |
0 | -----BEGIN PUBLIC KEY----- | |
1 | MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi | |
2 | QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg== | |
3 | -----END PUBLIC KEY----- |
67 | 67 | v.verify(macaroon, self.hs.config.macaroon_secret_key) |
68 | 68 | |
69 | 69 | def test_short_term_login_token_gives_user_id(self): |
70 | token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) | |
71 | user_id = self.get_success( | |
72 | self.auth_handler.validate_short_term_login_token_and_get_user_id(token) | |
73 | ) | |
74 | self.assertEqual("a_user", user_id) | |
70 | token = self.macaroon_generator.generate_short_term_login_token( | |
71 | "a_user", "", 5000 | |
72 | ) | |
73 | res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) | |
74 | self.assertEqual("a_user", res.user_id) | |
75 | self.assertEqual("", res.auth_provider_id) | |
75 | 76 | |
76 | 77 | # when we advance the clock, the token should be rejected |
77 | 78 | self.reactor.advance(6) |
78 | 79 | self.get_failure( |
79 | self.auth_handler.validate_short_term_login_token_and_get_user_id(token), | |
80 | self.auth_handler.validate_short_term_login_token(token), | |
80 | 81 | AuthError, |
81 | 82 | ) |
82 | 83 | |
84 | def test_short_term_login_token_gives_auth_provider(self): | |
85 | token = self.macaroon_generator.generate_short_term_login_token( | |
86 | "a_user", auth_provider_id="my_idp" | |
87 | ) | |
88 | res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) | |
89 | self.assertEqual("a_user", res.user_id) | |
90 | self.assertEqual("my_idp", res.auth_provider_id) | |
91 | ||
83 | 92 | def test_short_term_login_token_cannot_replace_user_id(self): |
84 | token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) | |
93 | token = self.macaroon_generator.generate_short_term_login_token( | |
94 | "a_user", "", 5000 | |
95 | ) | |
85 | 96 | macaroon = pymacaroons.Macaroon.deserialize(token) |
86 | 97 | |
87 | user_id = self.get_success( | |
88 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
89 | macaroon.serialize() | |
90 | ) | |
91 | ) | |
92 | self.assertEqual("a_user", user_id) | |
98 | res = self.get_success( | |
99 | self.auth_handler.validate_short_term_login_token(macaroon.serialize()) | |
100 | ) | |
101 | self.assertEqual("a_user", res.user_id) | |
93 | 102 | |
94 | 103 | # add another "user_id" caveat, which might allow us to override the |
95 | 104 | # user_id. |
96 | 105 | macaroon.add_first_party_caveat("user_id = b_user") |
97 | 106 | |
98 | 107 | self.get_failure( |
99 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
100 | macaroon.serialize() | |
101 | ), | |
108 | self.auth_handler.validate_short_term_login_token(macaroon.serialize()), | |
102 | 109 | AuthError, |
103 | 110 | ) |
104 | 111 | |
112 | 119 | ) |
113 | 120 | |
114 | 121 | self.get_success( |
115 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
122 | self.auth_handler.validate_short_term_login_token( | |
116 | 123 | self._get_macaroon().serialize() |
117 | 124 | ) |
118 | 125 | ) |
134 | 141 | return_value=make_awaitable(self.large_number_of_users) |
135 | 142 | ) |
136 | 143 | self.get_failure( |
137 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
144 | self.auth_handler.validate_short_term_login_token( | |
138 | 145 | self._get_macaroon().serialize() |
139 | 146 | ), |
140 | 147 | ResourceLimitError, |
158 | 165 | ResourceLimitError, |
159 | 166 | ) |
160 | 167 | self.get_failure( |
161 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
168 | self.auth_handler.validate_short_term_login_token( | |
162 | 169 | self._get_macaroon().serialize() |
163 | 170 | ), |
164 | 171 | ResourceLimitError, |
174 | 181 | ) |
175 | 182 | ) |
176 | 183 | self.get_success( |
177 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
184 | self.auth_handler.validate_short_term_login_token( | |
178 | 185 | self._get_macaroon().serialize() |
179 | 186 | ) |
180 | 187 | ) |
196 | 203 | return_value=make_awaitable(self.small_number_of_users) |
197 | 204 | ) |
198 | 205 | self.get_success( |
199 | self.auth_handler.validate_short_term_login_token_and_get_user_id( | |
206 | self.auth_handler.validate_short_term_login_token( | |
200 | 207 | self._get_macaroon().serialize() |
201 | 208 | ) |
202 | 209 | ) |
203 | 210 | |
204 | 211 | def _get_macaroon(self): |
205 | token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000) | |
212 | token = self.macaroon_generator.generate_short_term_login_token( | |
213 | "user_a", "", 5000 | |
214 | ) | |
206 | 215 | return pymacaroons.Macaroon.deserialize(token) |
65 | 65 | |
66 | 66 | # check that the auth handler got called as expected |
67 | 67 | auth_handler.complete_sso_login.assert_called_once_with( |
68 | "@test_user:test", request, "redirect_uri", None, new_user=True | |
68 | "@test_user:test", "cas", request, "redirect_uri", None, new_user=True | |
69 | 69 | ) |
70 | 70 | |
71 | 71 | def test_map_cas_user_to_existing_user(self): |
88 | 88 | |
89 | 89 | # check that the auth handler got called as expected |
90 | 90 | auth_handler.complete_sso_login.assert_called_once_with( |
91 | "@test_user:test", request, "redirect_uri", None, new_user=False | |
91 | "@test_user:test", "cas", request, "redirect_uri", None, new_user=False | |
92 | 92 | ) |
93 | 93 | |
94 | 94 | # Subsequent calls should map to the same mxid. |
97 | 97 | self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") |
98 | 98 | ) |
99 | 99 | auth_handler.complete_sso_login.assert_called_once_with( |
100 | "@test_user:test", request, "redirect_uri", None, new_user=False | |
100 | "@test_user:test", "cas", request, "redirect_uri", None, new_user=False | |
101 | 101 | ) |
102 | 102 | |
103 | 103 | def test_map_cas_user_to_invalid_localpart(self): |
115 | 115 | |
116 | 116 | # check that the auth handler got called as expected |
117 | 117 | auth_handler.complete_sso_login.assert_called_once_with( |
118 | "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True | |
118 | "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True | |
119 | 119 | ) |
120 | 120 | |
121 | 121 | @override_config( |
159 | 159 | |
160 | 160 | # check that the auth handler got called as expected |
161 | 161 | auth_handler.complete_sso_login.assert_called_once_with( |
162 | "@test_user:test", request, "redirect_uri", None, new_user=True | |
162 | "@test_user:test", "cas", request, "redirect_uri", None, new_user=True | |
163 | 163 | ) |
164 | 164 | |
165 | 165 |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import json |
15 | from typing import Optional | |
15 | import os | |
16 | 16 | from urllib.parse import parse_qs, urlparse |
17 | 17 | |
18 | 18 | from mock import ANY, Mock, patch |
22 | 22 | from synapse.handlers.sso import MappingException |
23 | 23 | from synapse.server import HomeServer |
24 | 24 | from synapse.types import UserID |
25 | from synapse.util.macaroons import get_value_from_macaroon | |
25 | 26 | |
26 | 27 | from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock |
27 | 28 | from tests.unittest import HomeserverTestCase, override_config |
49 | 50 | JWKS_URI = ISSUER + ".well-known/jwks.json" |
50 | 51 | |
51 | 52 | # config for common cases |
52 | COMMON_CONFIG = { | |
53 | DEFAULT_CONFIG = { | |
54 | "enabled": True, | |
55 | "client_id": CLIENT_ID, | |
56 | "client_secret": CLIENT_SECRET, | |
57 | "issuer": ISSUER, | |
58 | "scopes": SCOPES, | |
59 | "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, | |
60 | } | |
61 | ||
62 | # extends the default config with explicit OAuth2 endpoints instead of using discovery | |
63 | EXPLICIT_ENDPOINT_CONFIG = { | |
64 | **DEFAULT_CONFIG, | |
53 | 65 | "discover": False, |
54 | 66 | "authorization_endpoint": AUTHORIZATION_ENDPOINT, |
55 | 67 | "token_endpoint": TOKEN_ENDPOINT, |
106 | 118 | return {"keys": []} |
107 | 119 | |
108 | 120 | |
121 | def _key_file_path() -> str: | |
122 | """path to a file containing the private half of a test key""" | |
123 | ||
124 | # this key was generated with: | |
125 | # openssl ecparam -name prime256v1 -genkey -noout | | |
126 | # openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8 | |
127 | # | |
128 | # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because | |
129 | # that's what Apple use, and we want to be sure that we work with Apple's keys. | |
130 | # | |
131 | # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing | |
132 | # keys using ASN.1. Both are then typically formatted using PEM, which says: use the | |
133 | # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't | |
134 | # really need to care about any of that.) | |
135 | return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8") | |
136 | ||
137 | ||
138 | def _public_key_file_path() -> str: | |
139 | """path to a file containing the public half of a test key""" | |
140 | # this was generated with: | |
141 | # openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem | |
142 | # | |
143 | # See above about where oidc_test_key.p8 came from | |
144 | return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem") | |
145 | ||
146 | ||
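(Aside, not part of the diff: the PKCS8 PEM file described in the comments above can be read back in Python. A minimal sketch, assuming the `cryptography` package is available; the tests themselves hand the PEM text straight to authlib.)

    from cryptography.hazmat.primitives.serialization import load_pem_private_key

    # Load the PKCS8-encoded EC private key produced by the openssl commands
    # quoted above (illustrative sketch, not part of the diff).
    with open(_key_file_path(), "rb") as f:
        private_key = load_pem_private_key(f.read(), password=None)
    # private_key is an EllipticCurvePrivateKey on curve prime256v1.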
109 | 147 | class OidcHandlerTestCase(HomeserverTestCase): |
110 | 148 | if not HAS_OIDC: |
111 | 149 | skip = "requires OIDC" |
113 | 151 | def default_config(self): |
114 | 152 | config = super().default_config() |
115 | 153 | config["public_baseurl"] = BASE_URL |
116 | oidc_config = { | |
117 | "enabled": True, | |
118 | "client_id": CLIENT_ID, | |
119 | "client_secret": CLIENT_SECRET, | |
120 | "issuer": ISSUER, | |
121 | "scopes": SCOPES, | |
122 | "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, | |
123 | } | |
124 | ||
125 | # Update this config with what's in the default config so that | |
126 | # override_config works as expected. | |
127 | oidc_config.update(config.get("oidc_config", {})) | |
128 | config["oidc_config"] = oidc_config | |
129 | ||
130 | 154 | return config |
131 | 155 | |
132 | 156 | def make_homeserver(self, reactor, clock): |
169 | 193 | self.render_error.reset_mock() |
170 | 194 | return args |
171 | 195 | |
196 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
172 | 197 | def test_config(self): |
173 | 198 | """Basic config correctly sets up the callback URL and client auth correctly.""" |
174 | 199 | self.assertEqual(self.provider._callback_url, CALLBACK_URL) |
175 | 200 | self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) |
176 | 201 | self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) |
177 | 202 | |
178 | @override_config({"oidc_config": {"discover": True}}) | |
203 | @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) | |
179 | 204 | def test_discovery(self): |
180 | 205 | """The handler should discover the endpoints from OIDC discovery document.""" |
181 | 206 | # This would throw if some metadata were invalid |
194 | 219 | self.get_success(self.provider.load_metadata()) |
195 | 220 | self.http_client.get_json.assert_not_called() |
196 | 221 | |
197 | @override_config({"oidc_config": COMMON_CONFIG}) | |
222 | @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) | |
198 | 223 | def test_no_discovery(self): |
199 | 224 | """When discovery is disabled, it should not try to load from discovery document.""" |
200 | 225 | self.get_success(self.provider.load_metadata()) |
201 | 226 | self.http_client.get_json.assert_not_called() |
202 | 227 | |
203 | @override_config({"oidc_config": COMMON_CONFIG}) | |
228 | @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) | |
204 | 229 | def test_load_jwks(self): |
205 | 230 | """JWKS loading is done once (then cached) if used.""" |
206 | 231 | jwks = self.get_success(self.provider.load_jwks()) |
235 | 260 | self.http_client.get_json.assert_not_called() |
236 | 261 | self.assertEqual(jwks, {"keys": []}) |
237 | 262 | |
263 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
238 | 264 | def test_validate_config(self): |
239 | 265 | """Provider metadata is extensively validated."""
240 | 266 | h = self.provider |
317 | 343 | # Shouldn't raise with a valid userinfo, even without jwks |
318 | 344 | force_load_metadata() |
319 | 345 | |
320 | @override_config({"oidc_config": {"skip_verification": True}}) | |
346 | @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) | |
321 | 347 | def test_skip_verification(self): |
322 | 348 | """Provider metadata validation can be disabled by config.""" |
323 | 349 | with self.metadata_edit({"issuer": "http://insecure"}): |
324 | 350 | # This should not throw |
325 | 351 | get_awaitable_result(self.provider.load_metadata()) |
326 | 352 | |
353 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
327 | 354 | def test_redirect_request(self): |
328 | 355 | """The redirect request has the right arguments & generates a valid session cookie.""" |
329 | 356 | req = Mock(spec=["cookies"]) |
359 | 386 | self.assertEqual(name, b"oidc_session") |
360 | 387 | |
361 | 388 | macaroon = pymacaroons.Macaroon.deserialize(cookie) |
362 | state = self.handler._token_generator._get_value_from_macaroon( | |
363 | macaroon, "state" | |
364 | ) | |
365 | nonce = self.handler._token_generator._get_value_from_macaroon( | |
366 | macaroon, "nonce" | |
367 | ) | |
368 | redirect = self.handler._token_generator._get_value_from_macaroon( | |
369 | macaroon, "client_redirect_url" | |
370 | ) | |
389 | state = get_value_from_macaroon(macaroon, "state") | |
390 | nonce = get_value_from_macaroon(macaroon, "nonce") | |
391 | redirect = get_value_from_macaroon(macaroon, "client_redirect_url") | |
371 | 392 | |
372 | 393 | self.assertEqual(params["state"], [state]) |
373 | 394 | self.assertEqual(params["nonce"], [nonce]) |
374 | 395 | self.assertEqual(redirect, "http://client/redirect") |
375 | 396 | |
397 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
376 | 398 | def test_callback_error(self): |
377 | 399 | """Errors from the provider returned in the callback are displayed.""" |
378 | 400 | request = Mock(args={}) |
384 | 406 | self.get_success(self.handler.handle_oidc_callback(request)) |
385 | 407 | self.assertRenderedError("invalid_client", "some description") |
386 | 408 | |
409 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
387 | 410 | def test_callback(self): |
388 | 411 | """Code callback works and displays errors if something went wrong.
389 | 412 | |
433 | 456 | self.get_success(self.handler.handle_oidc_callback(request)) |
434 | 457 | |
435 | 458 | auth_handler.complete_sso_login.assert_called_once_with( |
436 | expected_user_id, request, client_redirect_url, None, new_user=True | |
459 | expected_user_id, "oidc", request, client_redirect_url, None, new_user=True | |
437 | 460 | ) |
438 | 461 | self.provider._exchange_code.assert_called_once_with(code) |
439 | 462 | self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) |
464 | 487 | self.get_success(self.handler.handle_oidc_callback(request)) |
465 | 488 | |
466 | 489 | auth_handler.complete_sso_login.assert_called_once_with( |
467 | expected_user_id, request, client_redirect_url, None, new_user=False | |
490 | expected_user_id, "oidc", request, client_redirect_url, None, new_user=False | |
468 | 491 | ) |
469 | 492 | self.provider._exchange_code.assert_called_once_with(code) |
470 | 493 | self.provider._parse_id_token.assert_not_called() |
485 | 508 | self.get_success(self.handler.handle_oidc_callback(request)) |
486 | 509 | self.assertRenderedError("invalid_request") |
487 | 510 | |
511 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
488 | 512 | def test_callback_session(self): |
489 | 513 | """The callback verifies the session presence and validity""" |
490 | 514 | request = Mock(spec=["args", "getCookie", "cookies"]) |
527 | 551 | self.get_success(self.handler.handle_oidc_callback(request)) |
528 | 552 | self.assertRenderedError("invalid_request") |
529 | 553 | |
530 | @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) | |
554 | @override_config( | |
555 | {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} | |
556 | ) | |
531 | 557 | def test_exchange_code(self): |
532 | 558 | """Code exchange behaves correctly and handles various error scenarios.""" |
533 | 559 | token = {"type": "bearer"} |
612 | 638 | @override_config( |
613 | 639 | { |
614 | 640 | "oidc_config": { |
641 | "enabled": True, | |
642 | "client_id": CLIENT_ID, | |
643 | "issuer": ISSUER, | |
644 | "client_auth_method": "client_secret_post", | |
645 | "client_secret_jwt_key": { | |
646 | "key_file": _key_file_path(), | |
647 | "jwt_header": {"alg": "ES256", "kid": "ABC789"}, | |
648 | "jwt_payload": {"iss": "DEFGHI"}, | |
649 | }, | |
650 | } | |
651 | } | |
652 | ) | |
653 | def test_exchange_code_jwt_key(self): | |
654 | """Test that code exchange works with a JWK client secret.""" | |
655 | from authlib.jose import jwt | |
656 | ||
657 | token = {"type": "bearer"} | |
658 | self.http_client.request = simple_async_mock( | |
659 | return_value=FakeResponse( | |
660 | code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") | |
661 | ) | |
662 | ) | |
663 | code = "code" | |
664 | ||
665 | # advance the clock a bit before we start, so we aren't working with zero | |
666 | # timestamps. | |
667 | self.reactor.advance(1000) | |
668 | start_time = self.reactor.seconds() | |
669 | ret = self.get_success(self.provider._exchange_code(code)) | |
670 | ||
671 | self.assertEqual(ret, token) | |
672 | ||
673 | # the request should have hit the token endpoint | |
674 | kwargs = self.http_client.request.call_args[1] | |
675 | self.assertEqual(kwargs["method"], "POST") | |
676 | self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) | |
677 | ||
678 | # the client secret provided to the token endpoint should be a jwt which can | |
679 | # be checked with the public key | |
680 | args = parse_qs(kwargs["data"].decode("utf-8")) | |
681 | secret = args["client_secret"][0] | |
682 | with open(_public_key_file_path()) as f: | |
683 | key = f.read() | |
684 | claims = jwt.decode(secret, key) | |
685 | self.assertEqual(claims.header["kid"], "ABC789") | |
686 | self.assertEqual(claims["aud"], ISSUER) | |
687 | self.assertEqual(claims["iss"], "DEFGHI") | |
688 | self.assertEqual(claims["sub"], CLIENT_ID) | |
689 | self.assertEqual(claims["iat"], start_time) | |
690 | self.assertGreater(claims["exp"], start_time) | |
691 | ||
692 | # check the rest of the POSTed data | |
693 | self.assertEqual(args["grant_type"], ["authorization_code"]) | |
694 | self.assertEqual(args["code"], [code]) | |
695 | self.assertEqual(args["client_id"], [CLIENT_ID]) | |
696 | self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) | |
697 | ||
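(Aside: for context on what test_exchange_code_jwt_key decodes above, a client-secret JWT of that shape could be produced with authlib roughly as below. This is a hedged sketch mirroring the override_config values; Synapse's actual generation code is outside this diff.)

    import time

    from authlib.jose import jwt

    # Sign a short-lived client secret with the ES256 test key (illustrative).
    now = int(time.time())
    header = {"alg": "ES256", "kid": "ABC789"}
    payload = {"iss": "DEFGHI", "sub": CLIENT_ID, "aud": ISSUER, "iat": now, "exp": now + 60}
    with open(_key_file_path()) as f:
        client_secret = jwt.encode(header, payload, f.read())  # bytes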
698 | @override_config( | |
699 | { | |
700 | "oidc_config": { | |
701 | "enabled": True, | |
702 | "client_id": CLIENT_ID, | |
703 | "issuer": ISSUER, | |
704 | "client_auth_method": "none", | |
705 | } | |
706 | } | |
707 | ) | |
708 | def test_exchange_code_no_auth(self): | |
709 | """Test that code exchange works with no client secret.""" | |
710 | token = {"type": "bearer"} | |
711 | self.http_client.request = simple_async_mock( | |
712 | return_value=FakeResponse( | |
713 | code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") | |
714 | ) | |
715 | ) | |
716 | code = "code" | |
717 | ret = self.get_success(self.provider._exchange_code(code)) | |
718 | ||
719 | self.assertEqual(ret, token) | |
720 | ||
721 | # the request should have hit the token endpoint | |
722 | kwargs = self.http_client.request.call_args[1] | |
723 | self.assertEqual(kwargs["method"], "POST") | |
724 | self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) | |
725 | ||
726 | # check the POSTed data | |
727 | args = parse_qs(kwargs["data"].decode("utf-8")) | |
728 | self.assertEqual(args["grant_type"], ["authorization_code"]) | |
729 | self.assertEqual(args["code"], [code]) | |
730 | self.assertEqual(args["client_id"], [CLIENT_ID]) | |
731 | self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) | |
732 | ||
733 | @override_config( | |
734 | { | |
735 | "oidc_config": { | |
736 | **DEFAULT_CONFIG, | |
615 | 737 | "user_mapping_provider": { |
616 | 738 | "module": __name__ + ".TestMappingProviderExtra" |
617 | } | |
739 | }, | |
618 | 740 | } |
619 | 741 | } |
620 | 742 | ) |
650 | 772 | |
651 | 773 | auth_handler.complete_sso_login.assert_called_once_with( |
652 | 774 | "@foo:test", |
775 | "oidc", | |
653 | 776 | request, |
654 | 777 | client_redirect_url, |
655 | 778 | {"phone": "1234567"}, |
656 | 779 | new_user=True, |
657 | 780 | ) |
658 | 781 | |
782 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
659 | 783 | def test_map_userinfo_to_user(self): |
660 | 784 | """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" |
661 | 785 | auth_handler = self.hs.get_auth_handler() |
667 | 791 | } |
668 | 792 | self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) |
669 | 793 | auth_handler.complete_sso_login.assert_called_once_with( |
670 | "@test_user:test", ANY, ANY, None, new_user=True | |
794 | "@test_user:test", "oidc", ANY, ANY, None, new_user=True | |
671 | 795 | ) |
672 | 796 | auth_handler.complete_sso_login.reset_mock() |
673 | 797 | |
678 | 802 | } |
679 | 803 | self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) |
680 | 804 | auth_handler.complete_sso_login.assert_called_once_with( |
681 | "@test_user_2:test", ANY, ANY, None, new_user=True | |
805 | "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True | |
682 | 806 | ) |
683 | 807 | auth_handler.complete_sso_login.reset_mock() |
684 | 808 | |
696 | 820 | "Mapping provider does not support de-duplicating Matrix IDs", |
697 | 821 | ) |
698 | 822 | |
699 | @override_config({"oidc_config": {"allow_existing_users": True}}) | |
823 | @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) | |
700 | 824 | def test_map_userinfo_to_existing_user(self): |
701 | 825 | """Existing users can log in with OpenID Connect when allow_existing_users is True.""" |
702 | 826 | store = self.hs.get_datastore() |
715 | 839 | } |
716 | 840 | self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) |
717 | 841 | auth_handler.complete_sso_login.assert_called_once_with( |
718 | user.to_string(), ANY, ANY, None, new_user=False | |
842 | user.to_string(), "oidc", ANY, ANY, None, new_user=False | |
719 | 843 | ) |
720 | 844 | auth_handler.complete_sso_login.reset_mock() |
721 | 845 | |
722 | 846 | # Subsequent calls should map to the same mxid. |
723 | 847 | self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) |
724 | 848 | auth_handler.complete_sso_login.assert_called_once_with( |
725 | user.to_string(), ANY, ANY, None, new_user=False | |
849 | user.to_string(), "oidc", ANY, ANY, None, new_user=False | |
726 | 850 | ) |
727 | 851 | auth_handler.complete_sso_login.reset_mock() |
728 | 852 | |
737 | 861 | } |
738 | 862 | self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) |
739 | 863 | auth_handler.complete_sso_login.assert_called_once_with( |
740 | user.to_string(), ANY, ANY, None, new_user=False | |
864 | user.to_string(), "oidc", ANY, ANY, None, new_user=False | |
741 | 865 | ) |
742 | 866 | auth_handler.complete_sso_login.reset_mock() |
743 | 867 | |
773 | 897 | |
774 | 898 | self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) |
775 | 899 | auth_handler.complete_sso_login.assert_called_once_with( |
776 | "@TEST_USER_2:test", ANY, ANY, None, new_user=False | |
777 | ) | |
778 | ||
900 | "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False | |
901 | ) | |
902 | ||
903 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
779 | 904 | def test_map_userinfo_to_invalid_localpart(self): |
780 | 905 | """If the mapping provider generates an invalid localpart it should be rejected.""" |
781 | 906 | self.get_success( |
786 | 911 | @override_config( |
787 | 912 | { |
788 | 913 | "oidc_config": { |
914 | **DEFAULT_CONFIG, | |
789 | 915 | "user_mapping_provider": { |
790 | 916 | "module": __name__ + ".TestMappingProviderFailures" |
791 | } | |
917 | }, | |
792 | 918 | } |
793 | 919 | } |
794 | 920 | ) |
809 | 935 | |
810 | 936 | # test_user is already taken, so test_user1 gets registered instead. |
811 | 937 | auth_handler.complete_sso_login.assert_called_once_with( |
812 | "@test_user1:test", ANY, ANY, None, new_user=True | |
938 | "@test_user1:test", "oidc", ANY, ANY, None, new_user=True | |
813 | 939 | ) |
814 | 940 | auth_handler.complete_sso_login.reset_mock() |
815 | 941 | |
833 | 959 | "mapping_error", "Unable to generate a Matrix ID from the SSO response" |
834 | 960 | ) |
835 | 961 | |
962 | @override_config({"oidc_config": DEFAULT_CONFIG}) | |
836 | 963 | def test_empty_localpart(self): |
837 | 964 | """Attempts to map onto an empty localpart should be rejected.""" |
838 | 965 | userinfo = { |
845 | 972 | @override_config( |
846 | 973 | { |
847 | 974 | "oidc_config": { |
975 | **DEFAULT_CONFIG, | |
848 | 976 | "user_mapping_provider": { |
849 | 977 | "config": {"localpart_template": "{{ user.username }}"} |
850 | } | |
978 | }, | |
851 | 979 | } |
852 | 980 | } |
853 | 981 | ) |
865 | 993 | state: str, |
866 | 994 | nonce: str, |
867 | 995 | client_redirect_url: str, |
868 | ui_auth_session_id: Optional[str] = None, | |
996 | ui_auth_session_id: str = "", | |
869 | 997 | ) -> str: |
870 | 998 | from synapse.handlers.oidc_handler import OidcSessionData |
871 | 999 | |
908 | 1036 | idp_id="oidc", |
909 | 1037 | nonce="nonce", |
910 | 1038 | client_redirect_url=client_redirect_url, |
1039 | ui_auth_session_id="", | |
911 | 1040 | ), |
912 | 1041 | ) |
913 | 1042 | request = _build_callback_request("code", state, session) |
516 | 516 | |
517 | 517 | self.assertTrue(requester.shadow_banned) |
518 | 518 | |
519 | def test_spam_checker_receives_sso_type(self): | |
520 | """Test rejecting registration based on SSO type""" | |
521 | ||
522 | class BanBadIdPUser: | |
523 | def check_registration_for_spam( | |
524 | self, email_threepid, username, request_info, auth_provider_id=None | |
525 | ): | |
526 | # Reject any user coming from CAS and whose username contains profanity | |
527 | if auth_provider_id == "cas" and "flimflob" in username: | |
528 | return RegistrationBehaviour.DENY | |
529 | return RegistrationBehaviour.ALLOW | |
530 | ||
531 | # Configure a spam checker that denies a certain user on a specific IdP | |
532 | spam_checker = self.hs.get_spam_checker() | |
533 | spam_checker.spam_checkers = [BanBadIdPUser()] | |
534 | ||
535 | f = self.get_failure( | |
536 | self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), | |
537 | SynapseError, | |
538 | ) | |
539 | exception = f.value | |
540 | ||
541 | # We return 429 from the spam checker for denied registrations | |
542 | self.assertIsInstance(exception, SynapseError) | |
543 | self.assertEqual(exception.code, 429) | |
544 | ||
545 | # Check the same username can register using SAML | |
546 | self.get_success( | |
547 | self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml") | |
548 | ) | |
549 | ||
519 | 550 | async def get_or_create_user( |
520 | 551 | self, requester, localpart, displayname, password_hash=None |
521 | 552 | ): |
130 | 130 | |
131 | 131 | # check that the auth handler got called as expected |
132 | 132 | auth_handler.complete_sso_login.assert_called_once_with( |
133 | "@test_user:test", request, "redirect_uri", None, new_user=True | |
133 | "@test_user:test", "saml", request, "redirect_uri", None, new_user=True | |
134 | 134 | ) |
135 | 135 | |
136 | 136 | @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) |
156 | 156 | |
157 | 157 | # check that the auth handler got called as expected |
158 | 158 | auth_handler.complete_sso_login.assert_called_once_with( |
159 | "@test_user:test", request, "", None, new_user=False | |
159 | "@test_user:test", "saml", request, "", None, new_user=False | |
160 | 160 | ) |
161 | 161 | |
162 | 162 | # Subsequent calls should map to the same mxid. |
165 | 165 | self.handler._handle_authn_response(request, saml_response, "") |
166 | 166 | ) |
167 | 167 | auth_handler.complete_sso_login.assert_called_once_with( |
168 | "@test_user:test", request, "", None, new_user=False | |
168 | "@test_user:test", "saml", request, "", None, new_user=False | |
169 | 169 | ) |
170 | 170 | |
171 | 171 | def test_map_saml_response_to_invalid_localpart(self): |
213 | 213 | |
214 | 214 | # test_user is already taken, so test_user1 gets registered instead. |
215 | 215 | auth_handler.complete_sso_login.assert_called_once_with( |
216 | "@test_user1:test", request, "", None, new_user=True | |
216 | "@test_user1:test", "saml", request, "", None, new_user=True | |
217 | 217 | ) |
218 | 218 | auth_handler.complete_sso_login.reset_mock() |
219 | 219 | |
309 | 309 | |
310 | 310 | # check that the auth handler got called as expected |
311 | 311 | auth_handler.complete_sso_login.assert_called_once_with( |
312 | "@test_user:test", request, "redirect_uri", None, new_user=True | |
312 | "@test_user:test", "saml", request, "redirect_uri", None, new_user=True | |
313 | 313 | ) |
314 | 314 | |
315 | 315 |
15 | 15 | |
16 | 16 | from mock import Mock |
17 | 17 | |
18 | from netaddr import IPSet | |
19 | ||
20 | from twisted.internet.error import DNSLookupError | |
18 | 21 | from twisted.python.failure import Failure |
19 | from twisted.web.client import ResponseDone | |
22 | from twisted.test.proto_helpers import AccumulatingProtocol | |
23 | from twisted.web.client import Agent, ResponseDone | |
20 | 24 | from twisted.web.iweb import UNKNOWN_LENGTH |
21 | 25 | |
22 | from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size | |
23 | ||
26 | from synapse.api.errors import SynapseError | |
27 | from synapse.http.client import ( | |
28 | BlacklistingAgentWrapper, | |
29 | BlacklistingReactorWrapper, | |
30 | BodyExceededMaxSize, | |
31 | read_body_with_max_size, | |
32 | ) | |
33 | ||
34 | from tests.server import FakeTransport, get_clock | |
24 | 35 | from tests.unittest import TestCase |
25 | 36 | |
26 | 37 | |
118 | 129 | |
119 | 130 | # The data is never consumed. |
120 | 131 | self.assertEqual(result.getvalue(), b"") |
132 | ||
133 | ||
134 | class BlacklistingAgentTest(TestCase): | |
135 | def setUp(self): | |
136 | self.reactor, self.clock = get_clock() | |
137 | ||
138 | self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4" | |
139 | self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8" | |
140 | self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" | |
141 | ||
142 | # Configure the reactor's DNS resolver. | |
143 | for (domain, ip) in ( | |
144 | (self.safe_domain, self.safe_ip), | |
145 | (self.unsafe_domain, self.unsafe_ip), | |
146 | (self.allowed_domain, self.allowed_ip), | |
147 | ): | |
148 | self.reactor.lookups[domain.decode()] = ip.decode() | |
149 | self.reactor.lookups[ip.decode()] = ip.decode() | |
150 | ||
151 | self.ip_whitelist = IPSet([self.allowed_ip.decode()]) | |
152 | self.ip_blacklist = IPSet(["5.0.0.0/8"]) | |
153 | ||
154 | def test_reactor(self): | |
155 | """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs.""" | |
156 | agent = Agent( | |
157 | BlacklistingReactorWrapper( | |
158 | self.reactor, | |
159 | ip_whitelist=self.ip_whitelist, | |
160 | ip_blacklist=self.ip_blacklist, | |
161 | ), | |
162 | ) | |
163 | ||
164 | # The unsafe domains and IPs should be rejected. | |
165 | for domain in (self.unsafe_domain, self.unsafe_ip): | |
166 | self.failureResultOf( | |
167 | agent.request(b"GET", b"http://" + domain), DNSLookupError | |
168 | ) | |
169 | ||
170 | # The safe domains and IPs should be accepted. | |
171 | for domain in ( | |
172 | self.safe_domain, | |
173 | self.allowed_domain, | |
174 | self.safe_ip, | |
175 | self.allowed_ip, | |
176 | ): | |
177 | d = agent.request(b"GET", b"http://" + domain) | |
178 | ||
179 | # Grab the latest TCP connection. | |
180 | ( | |
181 | host, | |
182 | port, | |
183 | client_factory, | |
184 | _timeout, | |
185 | _bindAddress, | |
186 | ) = self.reactor.tcpClients[-1] | |
187 | ||
188 | # Make the connection and pump data through it. | |
189 | client = client_factory.buildProtocol(None) | |
190 | server = AccumulatingProtocol() | |
191 | server.makeConnection(FakeTransport(client, self.reactor)) | |
192 | client.makeConnection(FakeTransport(server, self.reactor)) | |
193 | client.dataReceived( | |
194 | b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" | |
195 | ) | |
196 | ||
197 | response = self.successResultOf(d) | |
198 | self.assertEqual(response.code, 200) | |
199 | ||
200 | def test_agent(self): | |
201 | """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs.""" | |
202 | agent = BlacklistingAgentWrapper( | |
203 | Agent(self.reactor), | |
204 | ip_whitelist=self.ip_whitelist, | |
205 | ip_blacklist=self.ip_blacklist, | |
206 | ) | |
207 | ||
208 | # The unsafe IPs should be rejected. | |
209 | self.failureResultOf( | |
210 | agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError | |
211 | ) | |
212 | ||
213 | # The safe and unsafe domains and safe IPs should be accepted. | |
214 | for domain in ( | |
215 | self.safe_domain, | |
216 | self.unsafe_domain, | |
217 | self.allowed_domain, | |
218 | self.safe_ip, | |
219 | self.allowed_ip, | |
220 | ): | |
221 | d = agent.request(b"GET", b"http://" + domain) | |
222 | ||
223 | # Grab the latest TCP connection. | |
224 | ( | |
225 | host, | |
226 | port, | |
227 | client_factory, | |
228 | _timeout, | |
229 | _bindAddress, | |
230 | ) = self.reactor.tcpClients[-1] | |
231 | ||
232 | # Make the connection and pump data through it. | |
233 | client = client_factory.buildProtocol(None) | |
234 | server = AccumulatingProtocol() | |
235 | server.makeConnection(FakeTransport(client, self.reactor)) | |
236 | client.makeConnection(FakeTransport(server, self.reactor)) | |
237 | client.dataReceived( | |
238 | b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n" | |
239 | ) | |
240 | ||
241 | response = self.successResultOf(d) | |
242 | self.assertEqual(response.code, 200) |
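(Aside: the two tests above expect different failure types because the wrappers block at different layers. A short sketch of the distinction; `reactor`, `allow` and `deny` are placeholders for the values the tests build in setUp.)

    # The reactor wrapper rejects blacklisted hosts during DNS resolution, so
    # both bad hostnames and bad IP literals fail with DNSLookupError.
    reactor_agent = Agent(
        BlacklistingReactorWrapper(reactor, ip_whitelist=allow, ip_blacklist=deny)
    )

    # The agent wrapper only inspects the URI host before connecting, so it
    # rejects IP literals (SynapseError) but lets hostnames through to DNS.
    wrapped_agent = BlacklistingAgentWrapper(
        Agent(reactor), ip_whitelist=allow, ip_blacklist=deny
    )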
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import logging |
15 | from typing import Any, Callable, Dict, List, Optional, Tuple | |
16 | ||
17 | import attr | |
15 | from typing import Any, Callable, Dict, List, Optional, Tuple, Type | |
18 | 16 | |
19 | 17 | from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime |
20 | 18 | from twisted.internet.protocol import Protocol |
21 | 19 | from twisted.internet.task import LoopingCall |
22 | 20 | from twisted.web.http import HTTPChannel |
23 | 21 | from twisted.web.resource import Resource |
22 | from twisted.web.server import Request, Site | |
24 | 23 | |
25 | 24 | from synapse.app.generic_worker import ( |
26 | 25 | GenericWorkerReplicationHandler, |
31 | 30 | from synapse.replication.http import ReplicationRestResource |
32 | 31 | from synapse.replication.tcp.handler import ReplicationCommandHandler |
33 | 32 | from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol |
34 | from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory | |
33 | from synapse.replication.tcp.resource import ( | |
34 | ReplicationStreamProtocolFactory, | |
35 | ServerReplicationStreamProtocol, | |
36 | ) | |
35 | 37 | from synapse.server import HomeServer |
36 | 38 | from synapse.util import Clock |
37 | 39 | |
58 | 60 | # build a replication server |
59 | 61 | server_factory = ReplicationStreamProtocolFactory(hs) |
60 | 62 | self.streamer = hs.get_replication_streamer() |
61 | self.server = server_factory.buildProtocol(None) | |
63 | self.server = server_factory.buildProtocol( | |
64 | None | |
65 | ) # type: ServerReplicationStreamProtocol | |
62 | 66 | |
63 | 67 | # Make a new HomeServer object for the worker |
64 | 68 | self.reactor.lookups["testserv"] = "1.2.3.4" |
151 | 155 | # Set up client side protocol |
152 | 156 | client_protocol = client_factory.buildProtocol(None) |
153 | 157 | |
154 | request_factory = OneShotRequestFactory() | |
155 | ||
156 | 158 | # Set up the server side protocol |
157 | channel = _PushHTTPChannel(self.reactor) | |
158 | channel.requestFactory = request_factory | |
159 | channel.site = self.site | |
159 | channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site) | |
160 | 160 | |
161 | 161 | # Connect client to server and vice versa. |
162 | 162 | client_to_server_transport = FakeTransport( |
178 | 178 | server_to_client_transport.loseConnection() |
179 | 179 | client_to_server_transport.loseConnection() |
180 | 180 | |
181 | return request_factory.request | |
181 | return channel.request | |
182 | 182 | |
183 | 183 | def assert_request_is_get_repl_stream_updates( |
184 | 184 | self, request: SynapseRequest, stream_name: str |
187 | 187 | fetching updates for given stream. |
188 | 188 | """ |
189 | 189 | |
190 | path = request.path # type: bytes # type: ignore | |
190 | 191 | self.assertRegex( |
191 | request.path, | |
192 | path, | |
192 | 193 | br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" |
193 | 194 | % (stream_name.encode("ascii"),), |
194 | 195 | ) |
231 | 232 | if self.hs.config.redis.redis_enabled: |
232 | 233 | # Handle attempts to connect to fake redis server. |
233 | 234 | self.reactor.add_tcp_client_callback( |
234 | "localhost", | |
235 | b"localhost", | |
235 | 236 | 6379, |
236 | 237 | self.connect_any_redis_attempts, |
237 | 238 | ) |
386 | 387 | # Set up client side protocol |
387 | 388 | client_protocol = client_factory.buildProtocol(None) |
388 | 389 | |
389 | request_factory = OneShotRequestFactory() | |
390 | ||
391 | 390 | # Set up the server side protocol |
392 | channel = _PushHTTPChannel(self.reactor) | |
393 | channel.requestFactory = request_factory | |
394 | channel.site = self._hs_to_site[hs] | |
391 | channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs]) | |
395 | 392 | |
396 | 393 | # Connect client to server and vice versa. |
397 | 394 | client_to_server_transport = FakeTransport( |
417 | 414 | clients = self.reactor.tcpClients |
418 | 415 | while clients: |
419 | 416 | (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) |
420 | self.assertEqual(host, "localhost") | |
417 | self.assertEqual(host, b"localhost") | |
421 | 418 | self.assertEqual(port, 6379) |
422 | 419 | |
423 | 420 | client_protocol = client_factory.buildProtocol(None) |
447 | 444 | await super().on_rdata(stream_name, instance_name, token, rows) |
448 | 445 | for r in rows: |
449 | 446 | self.received_rdata_rows.append((stream_name, token, r)) |
450 | ||
451 | ||
452 | @attr.s() | |
453 | class OneShotRequestFactory: | |
454 | """A simple request factory that generates a single `SynapseRequest` and | |
455 | stores it for future use. Can only be used once. | |
456 | """ | |
457 | ||
458 | request = attr.ib(default=None) | |
459 | ||
460 | def __call__(self, *args, **kwargs): | |
461 | assert self.request is None | |
462 | ||
463 | self.request = SynapseRequest(*args, **kwargs) | |
464 | return self.request | |
465 | 447 | |
466 | 448 | |
467 | 449 | class _PushHTTPChannel(HTTPChannel): |
474 | 456 | makes it very hard to test. |
475 | 457 | """ |
476 | 458 | |
477 | def __init__(self, reactor: IReactorTime): | |
459 | def __init__( | |
460 | self, reactor: IReactorTime, request_factory: Type[Request], site: Site | |
461 | ): | |
478 | 462 | super().__init__() |
479 | 463 | self.reactor = reactor |
464 | self.requestFactory = request_factory | |
465 | self.site = site | |
480 | 466 | |
481 | 467 | self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] |
482 | 468 | |
501 | 487 | # `handle_http_replication_attempt`. |
502 | 488 | request.responseHeaders.setRawHeaders(b"connection", [b"close"]) |
503 | 489 | return False |
490 | ||
491 | def requestDone(self, request): | |
492 | # Store the request for inspection. | |
493 | self.request = request | |
494 | super().requestDone(request) | |
504 | 495 | |
505 | 496 | |
506 | 497 | class _PullToPushProducer: |
588 | 579 | |
589 | 580 | class FakeRedisPubSubProtocol(Protocol): |
590 | 581 | """A connection from a client talking to the fake Redis server.""" |
582 | ||
583 | transport = None # type: Optional[FakeTransport] | |
591 | 584 | |
592 | 585 | def __init__(self, server: FakeRedisPubSubServer): |
593 | 586 | self._server = server |
633 | 626 | |
634 | 627 | def send(self, msg): |
635 | 628 | """Send a message back to the client.""" |
629 | assert self.transport is not None | |
630 | ||
636 | 631 | raw = self.encode(msg).encode("utf-8") |
637 | 632 | |
638 | 633 | self.transport.write(raw) |
16 | 16 | |
17 | 17 | from synapse.app.generic_worker import GenericWorkerServer |
18 | 18 | from synapse.replication.tcp.commands import FederationAckCommand |
19 | from synapse.replication.tcp.protocol import AbstractConnection | |
19 | from synapse.replication.tcp.protocol import IReplicationConnection | |
20 | 20 | from synapse.replication.tcp.streams.federation import FederationStream |
21 | 21 | |
22 | 22 | from tests.unittest import HomeserverTestCase |
50 | 50 | """ |
51 | 51 | rch = self.hs.get_tcp_replication() |
52 | 52 | |
53 | # wire up the ReplicationCommandHandler to a mock connection | |
54 | mock_connection = mock.Mock(spec=AbstractConnection) | |
53 | # wire up the ReplicationCommandHandler to a mock connection, which needs | |
54 | # to implement IReplicationConnection. (Note that Mock doesn't understand | |
55 | # interfaces, but casting an interface to a list gives the attributes.) | |
56 | mock_connection = mock.Mock(spec=list(IReplicationConnection)) | |
55 | 57 | rch.new_connection(mock_connection) |
56 | 58 | |
57 | 59 | # tell it it received an RDATA row |
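(Aside: the "casting an interface to a list" trick in the comment above works because iterating a zope.interface Interface yields the attribute names it defines, which is exactly what Mock(spec=...) accepts. A self-contained sketch; IGreeter is made up for demonstration.)

    import mock
    from zope.interface import Interface

    class IGreeter(Interface):
        def greet(name):
            """Say hello to `name`."""

    # list(IGreeter) == ["greet"], so the mock only allows that attribute.
    m = mock.Mock(spec=list(IGreeter))
    m.greet("world")  # fine
    # m.other()       # would raise AttributeError: not in the spec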
436 | 436 | channel = self.make_request("GET", "/_matrix/client/r0/login") |
437 | 437 | self.assertEqual(channel.code, 200, channel.result) |
438 | 438 | |
439 | expected_flows = [ | |
440 | {"type": "m.login.cas"}, | |
441 | {"type": "m.login.sso"}, | |
442 | {"type": "m.login.token"}, | |
443 | {"type": "m.login.password"}, | |
444 | ] + ADDITIONAL_LOGIN_FLOWS | |
445 | ||
446 | self.assertCountEqual(channel.json_body["flows"], expected_flows) | |
439 | expected_flow_types = [ | |
440 | "m.login.cas", | |
441 | "m.login.sso", | |
442 | "m.login.token", | |
443 | "m.login.password", | |
444 | ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] | |
445 | ||
446 | self.assertCountEqual( | |
447 | [f["type"] for f in channel.json_body["flows"]], expected_flow_types | |
448 | ) | |
447 | 449 | |
448 | 450 | @override_config({"experimental_features": {"msc2858_enabled": True}}) |
449 | 451 | def test_get_msc2858_login_flows(self): |
635 | 637 | ) |
636 | 638 | self.assertEqual(channel.code, 400, channel.result) |
637 | 639 | |
638 | def test_client_idp_redirect_msc2858_disabled(self): | |
639 | """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" | |
640 | channel = self._make_sso_redirect_request(True, "oidc") | |
641 | self.assertEqual(channel.code, 400, channel.result) | |
642 | self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") | |
643 | ||
644 | @override_config({"experimental_features": {"msc2858_enabled": True}}) | |
645 | 640 | def test_client_idp_redirect_to_unknown(self): |
646 | 641 | """If the client tries to pick an unknown IdP, return a 404""" |
647 | channel = self._make_sso_redirect_request(True, "xxx") | |
642 | channel = self._make_sso_redirect_request(False, "xxx") | |
648 | 643 | self.assertEqual(channel.code, 404, channel.result) |
649 | 644 | self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") |
650 | 645 | |
651 | @override_config({"experimental_features": {"msc2858_enabled": True}}) | |
652 | 646 | def test_client_idp_redirect_to_oidc(self): |
653 | 647 | """If the client pick a known IdP, redirect to it""" |
648 | channel = self._make_sso_redirect_request(False, "oidc") | |
649 | self.assertEqual(channel.code, 302, channel.result) | |
650 | oidc_uri = channel.headers.getRawHeaders("Location")[0] | |
651 | oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) | |
652 | ||
653 | # it should redirect us to the auth page of the OIDC server | |
654 | self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) | |
655 | ||
656 | @override_config({"experimental_features": {"msc2858_enabled": True}}) | |
657 | def test_client_msc2858_redirect_to_oidc(self): | |
658 | """Test the unstable API""" | |
654 | 659 | channel = self._make_sso_redirect_request(True, "oidc") |
655 | 660 | self.assertEqual(channel.code, 302, channel.result) |
656 | 661 | oidc_uri = channel.headers.getRawHeaders("Location")[0] |
658 | 663 | |
659 | 664 | # it should redirect us to the auth page of the OIDC server |
660 | 665 | self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) |
666 | ||
667 | def test_client_idp_redirect_msc2858_disabled(self): | |
668 | """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" | |
669 | channel = self._make_sso_redirect_request(True, "oidc") | |
670 | self.assertEqual(channel.code, 400, channel.result) | |
671 | self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") | |
661 | 672 | |
662 | 673 | def _make_sso_redirect_request( |
663 | 674 | self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None |
104 | 104 | self.assertEqual(test_body, body) |
105 | 105 | |
106 | 106 | |
107 | @attr.s | |
107 | @attr.s(slots=True, frozen=True) | |
108 | 108 | class _TestImage: |
109 | 109 | """An image for testing thumbnailing with the expected results |
110 | 110 | |
116 | 116 | test should just check for success. |
117 | 117 | expected_scaled: The expected bytes from scaled thumbnailing, or None if |
118 | 118 | test should just check for a valid image returned. |
119 | expected_found: True if the file should exist on the server, or False if | |
120 | a 404 is expected. | |
119 | 121 | """ |
120 | 122 | |
121 | 123 | data = attr.ib(type=bytes) |
122 | 124 | content_type = attr.ib(type=bytes) |
123 | 125 | extension = attr.ib(type=bytes) |
124 | expected_cropped = attr.ib(type=Optional[bytes]) | |
125 | expected_scaled = attr.ib(type=Optional[bytes]) | |
126 | expected_cropped = attr.ib(type=Optional[bytes], default=None) | |
127 | expected_scaled = attr.ib(type=Optional[bytes], default=None) | |
126 | 128 | expected_found = attr.ib(default=True, type=bool) |
127 | 129 | |
128 | 130 | |
152 | 154 | ), |
153 | 155 | ), |
154 | 156 | ), |
157 | # small png with transparency. | |
158 | ( | |
159 | _TestImage( | |
160 | unhexlify( | |
161 | b"89504e470d0a1a0a0000000d49484452000000010000000101000" | |
162 | b"00000376ef9240000000274524e5300010194fdae0000000a4944" | |
163 | b"4154789c636800000082008177cd72b60000000049454e44ae426" | |
164 | b"082" | |
165 | ), | |
166 | b"image/png", | |
167 | b".png", | |
168 | # Note that we don't check the output since it varies across | |
169 | # different versions of Pillow. | |
170 | ), | |
171 | ), | |
155 | 172 | # small lossless webp |
156 | 173 | ( |
157 | 174 | _TestImage( |
161 | 178 | ), |
162 | 179 | b"image/webp", |
163 | 180 | b".webp", |
164 | None, | |
165 | None, | |
166 | 181 | ), |
167 | 182 | ), |
168 | 183 | # an empty file |
171 | 186 | b"", |
172 | 187 | b"image/gif", |
173 | 188 | b".gif", |
174 | None, | |
175 | None, | |
176 | False, | |
189 | expected_found=False, | |
177 | 190 | ), |
178 | 191 | ), |
179 | 192 | ], |
15 | 15 | IReactorPluggableNameResolver, |
16 | 16 | IReactorTCP, |
17 | 17 | IResolverSimple, |
18 | ITransport, | |
18 | 19 | ) |
19 | 20 | from twisted.python.failure import Failure |
20 | 21 | from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock |
187 | 188 | |
188 | 189 | def make_request( |
189 | 190 | reactor, |
190 | site: Site, | |
191 | site: Union[Site, FakeSite], | |
191 | 192 | method, |
192 | 193 | path, |
193 | 194 | content=b"", |
466 | 467 | return clock, hs_clock |
467 | 468 | |
468 | 469 | |
470 | @implementer(ITransport) | |
469 | 471 | @attr.s(cmp=False) |
470 | 472 | class FakeTransport: |
471 | 473 | """ |
117 | 117 | r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) |
118 | 118 | self.assertTrue(r == [room2] or r == [room3]) |
119 | 119 | |
120 | @parameterized.expand([(True,), (False,)]) | |
121 | def test_auth_difference(self, use_chain_cover_index: bool): | |
120 | def _setup_auth_chain(self, use_chain_cover_index: bool) -> str: | |
122 | 121 | room_id = "@ROOM:local" |
123 | 122 | |
124 | 123 | # The silly auth graph we use to test the auth difference algorithm, |
164 | 163 | "j": 1, |
165 | 164 | } |
166 | 165 | |
167 | # Mark the room as not having a cover index | |
166 | # Mark the room as maybe having a cover index. | |
168 | 167 | |
169 | 168 | def store_room(txn): |
170 | 169 | self.store.db_pool.simple_insert_txn( |
221 | 220 | ) |
222 | 221 | ) |
223 | 222 | |
223 | return room_id | |
224 | ||
225 | @parameterized.expand([(True,), (False,)]) | |
226 | def test_auth_chain_ids(self, use_chain_cover_index: bool): | |
227 | room_id = self._setup_auth_chain(use_chain_cover_index) | |
228 | ||
229 | # a and b have the same auth chain. | |
230 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"])) | |
231 | self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) | |
232 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"])) | |
233 | self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) | |
234 | auth_chain_ids = self.get_success( | |
235 | self.store.get_auth_chain_ids(room_id, ["a", "b"]) | |
236 | ) | |
237 | self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) | |
238 | ||
239 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"])) | |
240 | self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) | |
241 | ||
242 | # d and e have the same auth chain. | |
243 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"])) | |
244 | self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) | |
245 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"])) | |
246 | self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) | |
247 | ||
248 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"])) | |
249 | self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) | |
250 | ||
251 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"])) | |
252 | self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) | |
253 | ||
254 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) | |
255 | self.assertEqual(auth_chain_ids, ["k"]) | |
256 | ||
257 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) | |
258 | self.assertEqual(auth_chain_ids, ["j"]) | |
259 | ||
260 | # j and k have no parents. | |
261 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) | |
262 | self.assertEqual(auth_chain_ids, []) | |
263 | auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) | |
264 | self.assertEqual(auth_chain_ids, []) | |
265 | ||
266 | # More complex input sequences. | |
267 | auth_chain_ids = self.get_success( | |
268 | self.store.get_auth_chain_ids(room_id, ["b", "c", "d"]) | |
269 | ) | |
270 | self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) | |
271 | ||
272 | auth_chain_ids = self.get_success( | |
273 | self.store.get_auth_chain_ids(room_id, ["h", "i"]) | |
274 | ) | |
275 | self.assertCountEqual(auth_chain_ids, ["k", "j"]) | |
276 | ||
277 | # e gets returned even though include_given is false, but it is in the | |
278 | # auth chain of b. | |
279 | auth_chain_ids = self.get_success( | |
280 | self.store.get_auth_chain_ids(room_id, ["b", "e"]) | |
281 | ) | |
282 | self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) | |
283 | ||
284 | # Test include_given. | |
285 | auth_chain_ids = self.get_success( | |
286 | self.store.get_auth_chain_ids(room_id, ["i"], include_given=True) | |
287 | ) | |
288 | self.assertCountEqual(auth_chain_ids, ["i", "j"]) | |
289 | ||
290 | @parameterized.expand([(True,), (False,)]) | |
291 | def test_auth_difference(self, use_chain_cover_index: bool): | |
292 | room_id = self._setup_auth_chain(use_chain_cover_index) | |
293 | ||
224 | 294 | # Now actually test that various combinations give the right result: |
225 | 295 | |
226 | 296 | difference = self.get_success( |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | from twisted.internet import defer | |
16 | ||
17 | from synapse.api.errors import NotFoundError | |
15 | from synapse.api.errors import NotFoundError, SynapseError | |
18 | 16 | from synapse.rest.client.v1 import room |
19 | 17 | |
20 | 18 | from tests.unittest import HomeserverTestCase |
32 | 30 | def prepare(self, reactor, clock, hs): |
33 | 31 | self.room_id = self.helper.create_room_as(self.user_id) |
34 | 32 | |
35 | def test_purge(self): | |
33 | self.store = hs.get_datastore() | |
34 | self.storage = self.hs.get_storage() | |
35 | ||
36 | def test_purge_history(self): | |
36 | 37 | """ |
37 | Purging a room will delete everything before the topological point. | |
38 | Purging a room's history will delete everything before the topological point. | |
38 | 39 | """ |
39 | 40 | # Send four messages to the room |
40 | 41 | first = self.helper.send(self.room_id, body="test1") |
42 | 43 | third = self.helper.send(self.room_id, body="test3") |
43 | 44 | last = self.helper.send(self.room_id, body="test4") |
44 | 45 | |
45 | store = self.hs.get_datastore() | |
46 | storage = self.hs.get_storage() | |
47 | ||
48 | 46 | # Get the topological token |
49 | 47 | token = self.get_success( |
50 | store.get_topological_token_for_event(last["event_id"]) | |
48 | self.store.get_topological_token_for_event(last["event_id"]) | |
51 | 49 | ) |
52 | 50 | token_str = self.get_success(token.to_string(self.hs.get_datastore())) |
53 | 51 | |
54 | 52 | # Purge everything before this topological token |
55 | 53 | self.get_success( |
56 | storage.purge_events.purge_history(self.room_id, token_str, True) | |
54 | self.storage.purge_events.purge_history(self.room_id, token_str, True) | |
57 | 55 | ) |
58 | 56 | |
59 | 57 | # 1-3 should fail and last will succeed, meaning that 1-3 are deleted |
60 | 58 | # and last is not. |
61 | self.get_failure(store.get_event(first["event_id"]), NotFoundError) | |
62 | self.get_failure(store.get_event(second["event_id"]), NotFoundError) | |
63 | self.get_failure(store.get_event(third["event_id"]), NotFoundError) | |
64 | self.get_success(store.get_event(last["event_id"])) | |
59 | self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) | |
60 | self.get_failure(self.store.get_event(second["event_id"]), NotFoundError) | |
61 | self.get_failure(self.store.get_event(third["event_id"]), NotFoundError) | |
62 | self.get_success(self.store.get_event(last["event_id"])) | |
65 | 63 | |
66 | def test_purge_wont_delete_extrems(self): | |
64 | def test_purge_history_wont_delete_extrems(self): | |
67 | 65 | """ |
68 | Purging a room will delete everything before the topological point. | |
66 | Purging a room's history will delete everything before the topological point. | |
69 | 67 | """ |
70 | 68 | # Send four messages to the room |
71 | 69 | first = self.helper.send(self.room_id, body="test1") |
73 | 71 | third = self.helper.send(self.room_id, body="test3") |
74 | 72 | last = self.helper.send(self.room_id, body="test4") |
75 | 73 | |
76 | storage = self.hs.get_datastore() | |
77 | ||
78 | 74 | # Set the topological token higher than it should be |
79 | 75 | token = self.get_success( |
80 | storage.get_topological_token_for_event(last["event_id"]) | |
76 | self.store.get_topological_token_for_event(last["event_id"]) | |
81 | 77 | ) |
82 | 78 | event = "t{}-{}".format(token.topological + 1, token.stream + 1) |
83 | 79 | |
84 | 80 | # Purge everything before this topological token |
85 | purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) | |
86 | self.pump() | |
87 | f = self.failureResultOf(purge) | |
81 | f = self.get_failure( | |
82 | self.storage.purge_events.purge_history(self.room_id, event, True), | |
83 | SynapseError, | |
84 | ) | |
88 | 85 | self.assertIn("greater than forward", f.value.args[0]) |
89 | 86 | |
90 | 87 | # Try and get the events |
91 | self.get_success(storage.get_event(first["event_id"])) | |
92 | self.get_success(storage.get_event(second["event_id"])) | |
93 | self.get_success(storage.get_event(third["event_id"])) | |
94 | self.get_success(storage.get_event(last["event_id"])) | |
88 | self.get_success(self.store.get_event(first["event_id"])) | |
89 | self.get_success(self.store.get_event(second["event_id"])) | |
90 | self.get_success(self.store.get_event(third["event_id"])) | |
91 | self.get_success(self.store.get_event(last["event_id"])) | |
92 | ||
93 | def test_purge_room(self): | |
94 | """ | |
95 | Purging a room will delete everything about it. | |
96 | """ | |
97 | # Send four messages to the room | |
98 | first = self.helper.send(self.room_id, body="test1") | |
99 | ||
100 | # Get the current room state. | |
101 | state_handler = self.hs.get_state_handler() | |
102 | create_event = self.get_success( | |
103 | state_handler.get_current_state(self.room_id, "m.room.create", "") | |
104 | ) | |
105 | self.assertIsNotNone(create_event) | |
106 | ||
107 | # Purge everything before this topological token | |
108 | self.get_success(self.storage.purge_events.purge_room(self.room_id)) | |
109 | ||
110 | # The events aren't found. | |
111 | self.store._invalidate_get_event_cache(create_event.event_id) | |
112 | self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) | |
113 | self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) |
27 | 27 | def emit(self, record): |
28 | 28 | log_entry = self.format(record) |
29 | 29 | log_level = record.levelname.lower().replace("warning", "warn") |
30 | self.tx_log.emit( | |
30 | self.tx_log.emit( # type: ignore | |
31 | 31 | twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry |
32 | 32 | ) |
33 | 33 |
0 | # Copyright 2021 The Matrix.org Foundation C.I.C. | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | from synapse.util.caches.response_cache import ResponseCache | |
15 | ||
16 | from tests.server import get_clock | |
17 | from tests.unittest import TestCase | |
18 | ||
19 | ||
20 | class DeferredCacheTestCase(TestCase): | |
21 | """ | |
22 | A TestCase class for ResponseCache. | |
23 | ||
24 | The test-case function naming has some logic to its parts; here are some notes about it: | |
25 | wait: Denotes tests that have an element of "waiting" before its wrapped result becomes available | |
26 | (Generally these just use .delayed_return instead of .instant_return in its wrapped call.) | |
27 | expire: Denotes tests that test expiry after assured existence. | |
28 | (These use a cache with a short timeout_ms, shorter than the interval the test advances the clock by.) | |
29 | """ | |
30 | ||
31 | def setUp(self): | |
32 | self.reactor, self.clock = get_clock() | |
33 | ||
34 | def with_cache(self, name: str, ms: int = 0) -> ResponseCache: | |
35 | return ResponseCache(self.clock, name, timeout_ms=ms) | |
36 | ||
37 | @staticmethod | |
38 | async def instant_return(o: str) -> str: | |
39 | return o | |
40 | ||
41 | async def delayed_return(self, o: str) -> str: | |
42 | await self.clock.sleep(1) | |
43 | return o | |
44 | ||
45 | def test_cache_hit(self): | |
46 | cache = self.with_cache("keeping_cache", ms=9001) | |
47 | ||
48 | expected_result = "howdy" | |
49 | ||
50 | wrap_d = cache.wrap(0, self.instant_return, expected_result) | |
51 | ||
52 | self.assertEqual( | |
53 | expected_result, | |
54 | self.successResultOf(wrap_d), | |
55 | "initial wrap result should be the same", | |
56 | ) | |
57 | self.assertEqual( | |
58 | expected_result, | |
59 | self.successResultOf(cache.get(0)), | |
60 | "cache should have the result", | |
61 | ) | |
62 | ||
63 | def test_cache_miss(self): | |
64 | cache = self.with_cache("trashing_cache", ms=0) | |
65 | ||
66 | expected_result = "howdy" | |
67 | ||
68 | wrap_d = cache.wrap(0, self.instant_return, expected_result) | |
69 | ||
70 | self.assertEqual( | |
71 | expected_result, | |
72 | self.successResultOf(wrap_d), | |
73 | "initial wrap result should be the same", | |
74 | ) | |
75 | self.assertIsNone(cache.get(0), "cache should not have the result now") | |
76 | ||
77 | def test_cache_expire(self): | |
78 | cache = self.with_cache("short_cache", ms=1000) | |
79 | ||
80 | expected_result = "howdy" | |
81 | ||
82 | wrap_d = cache.wrap(0, self.instant_return, expected_result) | |
83 | ||
84 | self.assertEqual(expected_result, self.successResultOf(wrap_d)) | |
85 | self.assertEqual( | |
86 | expected_result, | |
87 | self.successResultOf(cache.get(0)), | |
88 | "cache should still have the result", | |
89 | ) | |
90 | ||
91 | # cache eviction timer is handled | |
92 | self.reactor.pump((2,)) | |
93 | ||
94 | self.assertIsNone(cache.get(0), "cache should not have the result now") | |
95 | ||
96 | def test_cache_wait_hit(self): | |
97 | cache = self.with_cache("neutral_cache") | |
98 | ||
99 | expected_result = "howdy" | |
100 | ||
101 | wrap_d = cache.wrap(0, self.delayed_return, expected_result) | |
102 | self.assertNoResult(wrap_d) | |
103 | ||
104 | # function wakes up, returns result | |
105 | self.reactor.pump((2,)) | |
106 | ||
107 | self.assertEqual(expected_result, self.successResultOf(wrap_d)) | |
108 | ||
109 | def test_cache_wait_expire(self): | |
110 | cache = self.with_cache("medium_cache", ms=3000) | |
111 | ||
112 | expected_result = "howdy" | |
113 | ||
114 | wrap_d = cache.wrap(0, self.delayed_return, expected_result) | |
115 | self.assertNoResult(wrap_d) | |
116 | ||
117 | # pump 1 second so the deferred fires and the eviction callLater is scheduled at that time, then another second to advance the clock to 2 | |
118 | self.reactor.pump((1, 1)) | |
119 | ||
120 | self.assertEqual(expected_result, self.successResultOf(wrap_d)) | |
121 | self.assertEqual( | |
122 | expected_result, | |
123 | self.successResultOf(cache.get(0)), | |
124 | "cache should still have the result", | |
125 | ) | |
126 | ||
127 | # (1 + 1 + 2) > 3.0, cache eviction timer is handled | |
128 | self.reactor.pump((2,)) | |
129 | ||
130 | self.assertIsNone(cache.get(0), "cache should not have the result now") |
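(Aside: a hedged usage sketch of the API these tests exercise; the profile-fetching caller below is hypothetical, and `clock` stands for the homeserver clock. Concurrent wrap() calls for the same key share one in-flight computation, and the result is retained for timeout_ms after completion.)

    cache = ResponseCache(clock, "profile_cache", timeout_ms=30000)

    async def fetch_profile(user_id):
        return {"displayname": user_id}  # placeholder for an expensive lookup

    # The second wrap() joins the first rather than starting a new computation,
    # and for the next 30s further calls get the cached result.
    d1 = cache.wrap("@alice:test", fetch_profile, "@alice:test")
    d2 = cache.wrap("@alice:test", fetch_profile, "@alice:test")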