diff --git a/src/bin/psql/startup.c b/src/bin/psql/startup.c index 5d7fe6e..55896e1 100644 --- a/src/bin/psql/startup.c +++ b/src/bin/psql/startup.c @@ -89,7 +89,6 @@ main(int argc, char *argv[]) int successResult; char *password = NULL; char *password_prompt = NULL; - bool new_pass; set_pglocale_pgservice(argv[0], PG_TEXTDOMAIN("psql")); @@ -197,50 +196,51 @@ main(int argc, char *argv[]) if (pset.getPassword == TRI_YES) password = simple_prompt(password_prompt, 100, false); - /* loop until we have a password if requested by backend */ - do - { #define PARAMS_ARRAY_SIZE 8 - const char **keywords = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*keywords)); - const char **values = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*values)); - - keywords[0] = "host"; - values[0] = options.host; - keywords[1] = "port"; - values[1] = options.port; - keywords[2] = "user"; - values[2] = options.username; - keywords[3] = "password"; - values[3] = password; - keywords[4] = "dbname"; - values[4] = (options.action == ACT_LIST_DB && - options.dbname == NULL) ? - "postgres" : options.dbname; - keywords[5] = "fallback_application_name"; - values[5] = pset.progname; - keywords[6] = "client_encoding"; - values[6] = (pset.notty || getenv("PGCLIENTENCODING")) ? NULL : "auto"; - keywords[7] = NULL; - values[7] = NULL; - - new_pass = false; - pset.db = PQconnectdbParams(keywords, values, true); - free(keywords); - free(values); - - if (PQstatus(pset.db) == CONNECTION_BAD && - PQconnectionNeedsPassword(pset.db) && - password == NULL && - pset.getPassword != TRI_NO) - { - PQfinish(pset.db); - password = simple_prompt(password_prompt, 100, false); - new_pass = true; - } - } while (new_pass); + const char **keywords = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*keywords)); + const char **values = pg_malloc(PARAMS_ARRAY_SIZE * sizeof(*values)); + + keywords[0] = "host"; + values[0] = options.host; + keywords[1] = "port"; + values[1] = options.port; + keywords[2] = "user"; + values[2] = options.username; + keywords[3] = "password"; + values[3] = password; + keywords[4] = "dbname"; + values[4] = (options.action == ACT_LIST_DB && + options.dbname == NULL) ? + "postgres" : options.dbname; + keywords[5] = "fallback_application_name"; + values[5] = pset.progname; + keywords[6] = "client_encoding"; + values[6] = (pset.notty || getenv("PGCLIENTENCODING")) ? NULL : "auto"; + keywords[7] = NULL; + values[7] = NULL; + + pset.db = PQconnectdbParams(keywords, values, true); + free(keywords); + free(values); - free(password); - free(password_prompt); + /* + * If backend asked for a password, prompt for one. + * Copy the same to pset.db to send it to the server + * using PQcontinuedbConnect() + */ + if (PQstatus(pset.db) == CONNECTION_ASKING_PASSWORD && + PQconnectionNeedsPassword(pset.db) && + password == NULL && + pset.getPassword != TRI_NO) + { + password = simple_prompt(password_prompt, 100, false); + PQcopyPassword(pset.db, password); + } + + /* + * Send the password over the existing connection + */ + PQcontinuedbConnect(pset.db); if (PQstatus(pset.db) == CONNECTION_BAD) { diff --git a/src/interfaces/libpq/exports.txt b/src/interfaces/libpq/exports.txt index 93da50d..ba0d695 100644 --- a/src/interfaces/libpq/exports.txt +++ b/src/interfaces/libpq/exports.txt @@ -165,3 +165,5 @@ lo_lseek64 162 lo_tell64 163 lo_truncate64 164 PQconninfo 165 +PQcopyPassword 166 +PQcontinuedbConnect 167 diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c index ae9dfaa..ddc9a46 100644 --- a/src/interfaces/libpq/fe-connect.c +++ b/src/interfaces/libpq/fe-connect.c @@ -466,6 +466,29 @@ PQconnectdbParams(const char *const * keywords, } /* + * PQcopyPassword + */ +void +PQcopyPassword(PGconn *conn, char *password) +{ + conn->pgpass = password; +} + +/* + * PQcontinuedbConnect + * + * Continue sending over the existing connection. Now send the password that + * the user just entered. Drive the connection protocol forward by sending server + * the password it asked for the last timewe left from connectDBComplete(). + */ +void +PQcontinuedbConnect(PGconn *conn) +{ + if(conn->status == CONNECTION_ASKING_PASSWORD) + (void) connectDBComplete(conn); +} + +/* * PQpingParams * * check server status, accepting parameters identical to PQconnectdbParams @@ -1508,7 +1531,10 @@ connectDBComplete(PGconn *conn) return 0; } break; - + case PGRES_POLLING_WAITING_PASSWORD: + conn->password_needed = true; + conn->status = CONNECTION_ASKING_PASSWORD; + return 0; default: /* Just in case we failed to set it in PQconnectPoll */ conn->status = CONNECTION_BAD; @@ -1555,6 +1581,7 @@ PQconnectPoll(PGconn *conn) PGresult *res; char sebuf[256]; int optval; + static AuthRequest areq; if (conn == NULL) return PGRES_POLLING_FAILED; @@ -1589,6 +1616,7 @@ PQconnectPoll(PGconn *conn) /* These are writing states, so we just proceed. */ case CONNECTION_STARTED: case CONNECTION_MADE: + case CONNECTION_ASKING_PASSWORD: break; /* We allow pqSetenvPoll to decide whether to proceed. */ @@ -2160,7 +2188,6 @@ keep_going: /* We will come back to here until there is char beresp; int msgLength; int avail; - AuthRequest areq; /* * Scan the message from current point (note that if we find @@ -2404,8 +2431,11 @@ keep_going: /* We will come back to here until there is */ conn->inStart = conn->inCursor; - /* Respond to the request if necessary. */ - + /* Respond to the request if necessary. + * If server asks for a password, we better prompt user for one + */ + if((conn->pgpass == NULL || conn->pgpass[0] == '\0') && (areq == AUTH_REQ_MD5 || areq == AUTH_REQ_PASSWORD)) + return PGRES_POLLING_WAITING_PASSWORD; /* * Note that conn->pghost must be non-NULL if we are going to * avoid the Kerberos code doing a hostname look-up. @@ -2442,7 +2472,34 @@ keep_going: /* We will come back to here until there is /* Look to see if we have more data yet. */ goto keep_going; } + case CONNECTION_ASKING_PASSWORD: + { + /* + * Note that conn->pghost must be non-NULL if we are going to + * avoid the Kerberos code doing a hostname look-up. + */ + if (pg_fe_sendauth(areq, conn) != STATUS_OK) + { + conn->errorMessage.len = strlen(conn->errorMessage.data); + goto error_return; + } + conn->errorMessage.len = strlen(conn->errorMessage.data); + + /* + * Just make sure that any data sent by pg_fe_sendauth is + * flushed out. Although this theoretically could block, it + * really shouldn't since we don't send large auth responses. + */ + if (pqFlush(conn)) + goto error_return; + /* + * Now go to read the server's response to password just sent + */ + conn->status = CONNECTION_AWAITING_RESPONSE; + return PGRES_POLLING_READING; + } + case CONNECTION_AUTH_OK: { /* diff --git a/src/interfaces/libpq/libpq-fe.h b/src/interfaces/libpq/libpq-fe.h index e0f4bc7..24bf1ec 100644 --- a/src/interfaces/libpq/libpq-fe.h +++ b/src/interfaces/libpq/libpq-fe.h @@ -62,7 +62,8 @@ typedef enum * backend startup. */ CONNECTION_SETENV, /* Negotiating environment. */ CONNECTION_SSL_STARTUP, /* Negotiating SSL. */ - CONNECTION_NEEDED /* Internal state: connect() needed */ + CONNECTION_NEEDED, /* Internal state: connect() needed */ + CONNECTION_ASKING_PASSWORD /* Useful to exchange password over an existing connection */ } ConnStatusType; typedef enum @@ -71,8 +72,9 @@ typedef enum PGRES_POLLING_READING, /* These two indicate that one may */ PGRES_POLLING_WRITING, /* use select before polling again. */ PGRES_POLLING_OK, - PGRES_POLLING_ACTIVE /* unused; keep for awhile for backwards + PGRES_POLLING_ACTIVE, /* unused; keep for awhile for backwards * compatibility */ + PGRES_POLLING_WAITING_PASSWORD /* Server asked for a password and getting it from the user */ } PostgresPollingStatusType; typedef enum @@ -258,6 +260,12 @@ extern PGconn *PQsetdbLogin(const char *pghost, const char *pgport, #define PQsetdb(M_PGHOST,M_PGPORT,M_PGOPT,M_PGTTY,M_DBNAME) \ PQsetdbLogin(M_PGHOST, M_PGPORT, M_PGOPT, M_PGTTY, M_DBNAME, NULL, NULL) +/* copy the password to the conn */ +extern void PQcopyPassword(PGconn *conn, char *password); + +/* Send the password user just entered to the server over an existing connection */ +extern void PQcontinuedbConnect(PGconn *conn); + /* close the current connection and free the PGconn data structure */ extern void PQfinish(PGconn *conn);